From 9c30aeacdbc92a7b93d369d424152acaf8470397 Mon Sep 17 00:00:00 2001 From: Skylot <118523+skylot@users.noreply.github.com> Date: Sun, 22 Sep 2024 20:10:37 +0100 Subject: [PATCH] refactor: split region maker --- .../visitors/regions/PostProcessRegions.java | 131 ++ .../dex/visitors/regions/RegionMaker.java | 1187 ----------------- .../visitors/regions/RegionMakerVisitor.java | 182 +-- .../regions/maker/ExcHandlersRegionMaker.java | 153 +++ .../IfRegionMaker.java} | 124 +- .../regions/maker/LoopRegionMaker.java | 464 +++++++ .../visitors/regions/maker/RegionMaker.java | 168 +++ .../regions/{ => maker}/RegionStack.java | 10 +- .../regions/maker/SwitchRegionMaker.java | 288 ++++ .../maker/SynchronizedRegionMaker.java | 162 +++ .../main/java/jadx/core/utils/BlockUtils.java | 44 + 11 files changed, 1548 insertions(+), 1365 deletions(-) create mode 100644 jadx-core/src/main/java/jadx/core/dex/visitors/regions/PostProcessRegions.java delete mode 100644 jadx-core/src/main/java/jadx/core/dex/visitors/regions/RegionMaker.java create mode 100644 jadx-core/src/main/java/jadx/core/dex/visitors/regions/maker/ExcHandlersRegionMaker.java rename jadx-core/src/main/java/jadx/core/dex/visitors/regions/{IfMakerHelper.java => maker/IfRegionMaker.java} (77%) create mode 100644 jadx-core/src/main/java/jadx/core/dex/visitors/regions/maker/LoopRegionMaker.java create mode 100644 jadx-core/src/main/java/jadx/core/dex/visitors/regions/maker/RegionMaker.java rename jadx-core/src/main/java/jadx/core/dex/visitors/regions/{ => maker}/RegionStack.java (93%) create mode 100644 jadx-core/src/main/java/jadx/core/dex/visitors/regions/maker/SwitchRegionMaker.java create mode 100644 jadx-core/src/main/java/jadx/core/dex/visitors/regions/maker/SynchronizedRegionMaker.java diff --git a/jadx-core/src/main/java/jadx/core/dex/visitors/regions/PostProcessRegions.java b/jadx-core/src/main/java/jadx/core/dex/visitors/regions/PostProcessRegions.java new file mode 100644 index 000000000..0e045e8a1 --- /dev/null +++ b/jadx-core/src/main/java/jadx/core/dex/visitors/regions/PostProcessRegions.java @@ -0,0 +1,131 @@ +package jadx.core.dex.visitors.regions; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import jadx.core.dex.attributes.AFlag; +import jadx.core.dex.attributes.AType; +import jadx.core.dex.attributes.nodes.EdgeInsnAttr; +import jadx.core.dex.instructions.InsnType; +import jadx.core.dex.nodes.BlockNode; +import jadx.core.dex.nodes.IBlock; +import jadx.core.dex.nodes.IContainer; +import jadx.core.dex.nodes.IRegion; +import jadx.core.dex.nodes.InsnContainer; +import jadx.core.dex.nodes.InsnNode; +import jadx.core.dex.nodes.MethodNode; +import jadx.core.dex.regions.Region; +import jadx.core.dex.regions.SwitchRegion; +import jadx.core.dex.regions.loops.LoopRegion; +import jadx.core.utils.RegionUtils; + +final class PostProcessRegions extends AbstractRegionVisitor { + private static final Logger LOG = LoggerFactory.getLogger(PostProcessRegions.class); + + private static final IRegionVisitor INSTANCE = new PostProcessRegions(); + + static void process(MethodNode mth) { + DepthRegionTraversal.traverse(mth, INSTANCE); + } + + @Override + public void leaveRegion(MethodNode mth, IRegion region) { + if (region instanceof LoopRegion) { + // merge conditions in loops + LoopRegion loop = (LoopRegion) region; + loop.mergePreCondition(); + } else if (region instanceof SwitchRegion) { + // insert 'break' in switch cases (run after try/catch insertion) + processSwitch(mth, (SwitchRegion) region); + } else if (region instanceof Region) { + insertEdgeInsn((Region) region); + } + } + + /** + * Insert insn block from edge insn attribute. + */ + private static void insertEdgeInsn(Region region) { + List subBlocks = region.getSubBlocks(); + if (subBlocks.isEmpty()) { + return; + } + IContainer last = subBlocks.get(subBlocks.size() - 1); + List edgeInsnAttrs = last.getAll(AType.EDGE_INSN); + if (edgeInsnAttrs.isEmpty()) { + return; + } + EdgeInsnAttr insnAttr = edgeInsnAttrs.get(0); + if (!insnAttr.getStart().equals(last)) { + return; + } + if (last instanceof BlockNode) { + BlockNode block = (BlockNode) last; + if (block.getInstructions().isEmpty()) { + block.getInstructions().add(insnAttr.getInsn()); + return; + } + } + List insns = Collections.singletonList(insnAttr.getInsn()); + region.add(new InsnContainer(insns)); + } + + private static void processSwitch(MethodNode mth, SwitchRegion sw) { + for (IContainer c : sw.getBranches()) { + if (c instanceof Region) { + Set blocks = new HashSet<>(); + RegionUtils.getAllRegionBlocks(c, blocks); + if (blocks.isEmpty()) { + addBreakToContainer((Region) c); + } else { + for (IBlock block : blocks) { + if (block instanceof BlockNode) { + addBreakForBlock(mth, c, blocks, (BlockNode) block); + } + } + } + } + } + } + + private static void addBreakToContainer(Region c) { + if (RegionUtils.hasExitEdge(c)) { + return; + } + List insns = new ArrayList<>(1); + insns.add(new InsnNode(InsnType.BREAK, 0)); + c.add(new InsnContainer(insns)); + } + + private static void addBreakForBlock(MethodNode mth, IContainer c, Set blocks, BlockNode bn) { + for (BlockNode s : bn.getCleanSuccessors()) { + if (!blocks.contains(s) + && !bn.contains(AFlag.ADDED_TO_REGION) + && !s.contains(AFlag.FALL_THROUGH)) { + addBreak(mth, c, bn); + return; + } + } + } + + private static void addBreak(MethodNode mth, IContainer c, BlockNode bn) { + IContainer blockContainer = RegionUtils.getBlockContainer(c, bn); + if (blockContainer instanceof Region) { + addBreakToContainer((Region) blockContainer); + } else if (c instanceof Region) { + addBreakToContainer((Region) c); + } else { + LOG.warn("Can't insert break, container: {}, block: {}, mth: {}", blockContainer, bn, mth); + } + } + + private PostProcessRegions() { + // singleton + } +} diff --git a/jadx-core/src/main/java/jadx/core/dex/visitors/regions/RegionMaker.java b/jadx-core/src/main/java/jadx/core/dex/visitors/regions/RegionMaker.java deleted file mode 100644 index 2112813e2..000000000 --- a/jadx-core/src/main/java/jadx/core/dex/visitors/regions/RegionMaker.java +++ /dev/null @@ -1,1187 +0,0 @@ -package jadx.core.dex.visitors.regions; - -import java.util.ArrayList; -import java.util.BitSet; -import java.util.HashSet; -import java.util.LinkedHashMap; -import java.util.LinkedHashSet; -import java.util.List; -import java.util.Map; -import java.util.Map.Entry; -import java.util.Objects; -import java.util.Optional; -import java.util.Set; - -import org.jetbrains.annotations.Nullable; - -import jadx.core.dex.attributes.AFlag; -import jadx.core.dex.attributes.AType; -import jadx.core.dex.attributes.nodes.EdgeInsnAttr; -import jadx.core.dex.attributes.nodes.LoopInfo; -import jadx.core.dex.attributes.nodes.LoopLabelAttr; -import jadx.core.dex.attributes.nodes.RegionRefAttr; -import jadx.core.dex.instructions.IfNode; -import jadx.core.dex.instructions.InsnType; -import jadx.core.dex.instructions.SwitchInsn; -import jadx.core.dex.instructions.args.InsnArg; -import jadx.core.dex.nodes.BlockNode; -import jadx.core.dex.nodes.Edge; -import jadx.core.dex.nodes.IBlock; -import jadx.core.dex.nodes.IContainer; -import jadx.core.dex.nodes.IRegion; -import jadx.core.dex.nodes.InsnContainer; -import jadx.core.dex.nodes.InsnNode; -import jadx.core.dex.nodes.MethodNode; -import jadx.core.dex.regions.Region; -import jadx.core.dex.regions.SwitchRegion; -import jadx.core.dex.regions.SynchronizedRegion; -import jadx.core.dex.regions.conditions.IfInfo; -import jadx.core.dex.regions.conditions.IfRegion; -import jadx.core.dex.regions.loops.LoopRegion; -import jadx.core.dex.trycatch.ExcHandlerAttr; -import jadx.core.dex.trycatch.ExceptionHandler; -import jadx.core.dex.trycatch.TryCatchBlockAttr; -import jadx.core.utils.BlockUtils; -import jadx.core.utils.ListUtils; -import jadx.core.utils.RegionUtils; -import jadx.core.utils.Utils; -import jadx.core.utils.exceptions.JadxOverflowException; -import jadx.core.utils.exceptions.JadxRuntimeException; - -import static jadx.core.dex.visitors.regions.IfMakerHelper.confirmMerge; -import static jadx.core.dex.visitors.regions.IfMakerHelper.makeIfInfo; -import static jadx.core.dex.visitors.regions.IfMakerHelper.mergeNestedIfNodes; -import static jadx.core.dex.visitors.regions.IfMakerHelper.searchNestedIf; -import static jadx.core.utils.BlockUtils.followEmptyPath; -import static jadx.core.utils.BlockUtils.getNextBlock; -import static jadx.core.utils.BlockUtils.isPathExists; - -public class RegionMaker { - private final MethodNode mth; - private final int regionsLimit; - private final BitSet processedBlocks; - private int regionsCount; - - public RegionMaker(MethodNode mth) { - this.mth = mth; - int blocksCount = mth.getBasicBlocks().size(); - this.processedBlocks = new BitSet(blocksCount); - this.regionsLimit = blocksCount * 100; - } - - public Region makeRegion(BlockNode startBlock, RegionStack stack) { - Region r = new Region(stack.peekRegion()); - if (startBlock == null) { - return r; - } - if (stack.containsExit(startBlock)) { - insertEdgeInsns(r, startBlock); - return r; - } - - int startBlockId = startBlock.getId(); - if (processedBlocks.get(startBlockId)) { - mth.addWarn("Removed duplicated region for block: " + startBlock + ' ' + startBlock.getAttributesString()); - return r; - } - processedBlocks.set(startBlockId); - - BlockNode next = startBlock; - while (next != null) { - next = traverse(r, next, stack); - regionsCount++; - if (regionsCount > regionsLimit) { - throw new JadxOverflowException("Regions count limit reached"); - } - } - return r; - } - - private void insertEdgeInsns(Region region, BlockNode exitBlock) { - List edgeInsns = exitBlock.getAll(AType.EDGE_INSN); - if (edgeInsns.isEmpty()) { - return; - } - List insns = new ArrayList<>(edgeInsns.size()); - addOneInsnOfType(insns, edgeInsns, InsnType.BREAK); - addOneInsnOfType(insns, edgeInsns, InsnType.CONTINUE); - region.add(new InsnContainer(insns)); - } - - private void addOneInsnOfType(List insns, List edgeInsns, InsnType insnType) { - for (EdgeInsnAttr edgeInsn : edgeInsns) { - InsnNode insn = edgeInsn.getInsn(); - if (insn.getType() == insnType) { - insns.add(insn); - return; - } - } - } - - /** - * Recursively traverse all blocks from 'block' until block from 'exits' - */ - private BlockNode traverse(IRegion r, BlockNode block, RegionStack stack) { - if (block.contains(AFlag.MTH_EXIT_BLOCK)) { - return null; - } - BlockNode next = null; - boolean processed = false; - - List loops = block.getAll(AType.LOOP); - int loopCount = loops.size(); - if (loopCount != 0 && block.contains(AFlag.LOOP_START)) { - if (loopCount == 1) { - next = processLoop(r, loops.get(0), stack); - processed = true; - } else { - for (LoopInfo loop : loops) { - if (loop.getStart() == block) { - next = processLoop(r, loop, stack); - processed = true; - break; - } - } - } - } - - InsnNode insn = BlockUtils.getLastInsn(block); - if (!processed && insn != null) { - switch (insn.getType()) { - case IF: - next = processIf(r, block, (IfNode) insn, stack); - processed = true; - break; - - case SWITCH: - next = processSwitch(r, block, (SwitchInsn) insn, stack); - processed = true; - break; - - case MONITOR_ENTER: - next = processMonitorEnter(r, block, insn, stack); - processed = true; - break; - - default: - break; - } - } - if (!processed) { - r.getSubBlocks().add(block); - next = getNextBlock(block); - } - if (next != null && !stack.containsExit(block) && !stack.containsExit(next)) { - return next; - } - return null; - } - - private BlockNode processLoop(IRegion curRegion, LoopInfo loop, RegionStack stack) { - BlockNode loopStart = loop.getStart(); - Set exitBlocksSet = loop.getExitNodes(); - - // set exit blocks scan order priority - // this can help if loop has several exits (after using 'break' or 'return' in loop) - List exitBlocks = new ArrayList<>(exitBlocksSet.size()); - BlockNode nextStart = getNextBlock(loopStart); - if (nextStart != null && exitBlocksSet.remove(nextStart)) { - exitBlocks.add(nextStart); - } - if (exitBlocksSet.remove(loopStart)) { - exitBlocks.add(loopStart); - } - if (exitBlocksSet.remove(loop.getEnd())) { - exitBlocks.add(loop.getEnd()); - } - exitBlocks.addAll(exitBlocksSet); - - LoopRegion loopRegion = makeLoopRegion(curRegion, loop, exitBlocks); - if (loopRegion == null) { - BlockNode exit = makeEndlessLoop(curRegion, stack, loop, loopStart); - insertContinue(loop); - return exit; - } - curRegion.getSubBlocks().add(loopRegion); - IRegion outerRegion = stack.peekRegion(); - stack.push(loopRegion); - - IfInfo condInfo = makeIfInfo(mth, loopRegion.getHeader()); - condInfo = searchNestedIf(condInfo); - confirmMerge(condInfo); - if (!loop.getLoopBlocks().contains(condInfo.getThenBlock())) { - // invert loop condition if 'then' points to exit - condInfo = IfInfo.invert(condInfo); - } - loopRegion.updateCondition(condInfo); - // prevent if's merge with loop condition - condInfo.getMergedBlocks().forEach(b -> b.add(AFlag.ADDED_TO_REGION)); - exitBlocks.removeAll(condInfo.getMergedBlocks()); - - if (!exitBlocks.isEmpty()) { - BlockNode loopExit = condInfo.getElseBlock(); - if (loopExit != null) { - // add 'break' instruction before path cross between main loop exit and sub-exit - for (Edge exitEdge : loop.getExitEdges()) { - if (exitBlocks.contains(exitEdge.getSource())) { - insertLoopBreak(stack, loop, loopExit, exitEdge); - } - } - } - } - - BlockNode out; - if (loopRegion.isConditionAtEnd()) { - BlockNode thenBlock = condInfo.getThenBlock(); - out = thenBlock == loop.getEnd() || thenBlock == loopStart ? condInfo.getElseBlock() : thenBlock; - out = BlockUtils.followEmptyPath(out); - loopStart.remove(AType.LOOP); - loop.getEnd().add(AFlag.ADDED_TO_REGION); - stack.addExit(loop.getEnd()); - processedBlocks.clear(loopStart.getId()); - Region body = makeRegion(loopStart, stack); - loopRegion.setBody(body); - loopStart.addAttr(AType.LOOP, loop); - loop.getEnd().remove(AFlag.ADDED_TO_REGION); - } else { - out = condInfo.getElseBlock(); - if (outerRegion != null - && out != null - && out.contains(AFlag.LOOP_START) - && !out.getAll(AType.LOOP).contains(loop) - && RegionUtils.isRegionContainsBlock(outerRegion, out)) { - // exit to already processed outer loop - out = null; - } - stack.addExit(out); - BlockNode loopBody = condInfo.getThenBlock(); - Region body; - if (Objects.equals(loopBody, loopStart)) { - // empty loop body - body = new Region(loopRegion); - } else { - body = makeRegion(loopBody, stack); - } - // add blocks from loop start to first condition block - BlockNode conditionBlock = condInfo.getFirstIfBlock(); - if (loopStart != conditionBlock) { - Set blocks = BlockUtils.getAllPathsBlocks(loopStart, conditionBlock); - blocks.remove(conditionBlock); - for (BlockNode block : blocks) { - if (block.getInstructions().isEmpty() - && !block.contains(AFlag.ADDED_TO_REGION) - && !RegionUtils.isRegionContainsBlock(body, block)) { - body.add(block); - } - } - } - loopRegion.setBody(body); - } - stack.pop(); - insertContinue(loop); - return out; - } - - /** - * Select loop exit and construct LoopRegion - */ - private LoopRegion makeLoopRegion(IRegion curRegion, LoopInfo loop, List exitBlocks) { - for (BlockNode block : exitBlocks) { - if (block.contains(AType.EXC_HANDLER)) { - continue; - } - InsnNode lastInsn = BlockUtils.getLastInsn(block); - if (lastInsn == null || lastInsn.getType() != InsnType.IF) { - continue; - } - List loops = block.getAll(AType.LOOP); - if (!loops.isEmpty() && loops.get(0) != loop) { - // skip nested loop condition - continue; - } - boolean exitAtLoopEnd = isExitAtLoopEnd(block, loop); - LoopRegion loopRegion = new LoopRegion(curRegion, loop, block, exitAtLoopEnd); - boolean found; - if (block == loop.getStart() || exitAtLoopEnd - || BlockUtils.isEmptySimplePath(loop.getStart(), block)) { - found = true; - } else if (block.getPredecessors().contains(loop.getStart())) { - loopRegion.setPreCondition(loop.getStart()); - // if we can't merge pre-condition this is not correct header - found = loopRegion.checkPreCondition(); - } else { - found = false; - } - if (found) { - List list = mth.getAllLoopsForBlock(block); - if (list.size() >= 2) { - // bad condition if successors going out of all loops - boolean allOuter = true; - for (BlockNode outerBlock : block.getCleanSuccessors()) { - List outLoopList = mth.getAllLoopsForBlock(outerBlock); - outLoopList.remove(loop); - if (!outLoopList.isEmpty()) { - // goes to outer loop - allOuter = false; - break; - } - } - if (allOuter) { - found = false; - } - } - } - if (found && !checkLoopExits(loop, block)) { - found = false; - } - if (found) { - return loopRegion; - } - } - // no exit found => endless loop - return null; - } - - private static boolean isExitAtLoopEnd(BlockNode exit, LoopInfo loop) { - BlockNode loopEnd = loop.getEnd(); - if (exit == loopEnd) { - return true; - } - BlockNode loopStart = loop.getStart(); - if (loopStart.getInstructions().isEmpty() && ListUtils.isSingleElement(loopStart.getSuccessors(), exit)) { - return false; - } - return loopEnd.getInstructions().isEmpty() && ListUtils.isSingleElement(loopEnd.getPredecessors(), exit); - } - - private boolean checkLoopExits(LoopInfo loop, BlockNode mainExitBlock) { - List exitEdges = loop.getExitEdges(); - if (exitEdges.size() < 2) { - return true; - } - Optional mainEdgeOpt = exitEdges.stream().filter(edge -> edge.getSource() == mainExitBlock).findFirst(); - if (mainEdgeOpt.isEmpty()) { - throw new JadxRuntimeException("Not found exit edge by exit block: " + mainExitBlock); - } - Edge mainExitEdge = mainEdgeOpt.get(); - BlockNode mainOutBlock = mainExitEdge.getTarget(); - for (Edge exitEdge : exitEdges) { - if (exitEdge != mainExitEdge) { - // all exit paths must be same or don't cross (will be inside loop) - BlockNode exitBlock = exitEdge.getTarget(); - if (!isEqualPaths(mainOutBlock, exitBlock)) { - BlockNode crossBlock = BlockUtils.getPathCross(mth, mainOutBlock, exitBlock); - if (crossBlock != null) { - return false; - } - } - } - } - return true; - } - - private BlockNode makeEndlessLoop(IRegion curRegion, RegionStack stack, LoopInfo loop, BlockNode loopStart) { - LoopRegion loopRegion = new LoopRegion(curRegion, loop, null, false); - curRegion.getSubBlocks().add(loopRegion); - - loopStart.remove(AType.LOOP); - processedBlocks.clear(loopStart.getId()); - stack.push(loopRegion); - - BlockNode out = null; - // insert 'break' for exits - List exitEdges = loop.getExitEdges(); - if (exitEdges.size() == 1) { - Edge exitEdge = exitEdges.get(0); - BlockNode exit = exitEdge.getTarget(); - if (insertLoopBreak(stack, loop, exit, exitEdge)) { - BlockNode nextBlock = getNextBlock(exit); - if (nextBlock != null) { - stack.addExit(nextBlock); - out = nextBlock; - } - } - } else { - for (Edge exitEdge : exitEdges) { - BlockNode exit = exitEdge.getTarget(); - List blocks = BlockUtils.bitSetToBlocks(mth, exit.getDomFrontier()); - for (BlockNode block : blocks) { - if (BlockUtils.isPathExists(exit, block)) { - stack.addExit(block); - insertLoopBreak(stack, loop, block, exitEdge); - out = block; - } else { - insertLoopBreak(stack, loop, exit, exitEdge); - } - } - } - } - - Region body = makeRegion(loopStart, stack); - BlockNode loopEnd = loop.getEnd(); - if (!RegionUtils.isRegionContainsBlock(body, loopEnd) - && !loopEnd.contains(AType.EXC_HANDLER) - && !inExceptionHandlerBlocks(loopEnd)) { - body.getSubBlocks().add(loopEnd); - } - loopRegion.setBody(body); - - if (out == null) { - BlockNode next = getNextBlock(loopEnd); - out = RegionUtils.isRegionContainsBlock(body, next) ? null : next; - } - stack.pop(); - loopStart.addAttr(AType.LOOP, loop); - return out; - } - - private boolean inExceptionHandlerBlocks(BlockNode loopEnd) { - if (mth.getExceptionHandlersCount() == 0) { - return false; - } - for (ExceptionHandler eh : mth.getExceptionHandlers()) { - if (eh.getBlocks().contains(loopEnd)) { - return true; - } - } - return false; - } - - private boolean canInsertBreak(BlockNode exit) { - if (BlockUtils.containsExitInsn(exit)) { - return false; - } - List simplePath = BlockUtils.buildSimplePath(exit); - if (!simplePath.isEmpty()) { - BlockNode lastBlock = simplePath.get(simplePath.size() - 1); - if (lastBlock.isMthExitBlock() - || lastBlock.isReturnBlock() - || mth.isPreExitBlock(lastBlock)) { - return false; - } - } - // check if there no outer switch (TODO: very expensive check) - Set paths = BlockUtils.getAllPathsBlocks(mth.getEnterBlock(), exit); - for (BlockNode block : paths) { - if (BlockUtils.checkLastInsnType(block, InsnType.SWITCH)) { - return false; - } - } - return true; - } - - private boolean insertLoopBreak(RegionStack stack, LoopInfo loop, BlockNode loopExit, Edge exitEdge) { - BlockNode exit = exitEdge.getTarget(); - Edge insertEdge = null; - boolean confirm = false; - // process special cases: - // 1. jump to outer loop - BlockNode exitEnd = BlockUtils.followEmptyPath(exit); - List loops = exitEnd.getAll(AType.LOOP); - for (LoopInfo loopAtEnd : loops) { - if (loopAtEnd != loop && loop.hasParent(loopAtEnd)) { - insertEdge = exitEdge; - confirm = true; - break; - } - } - - if (!confirm) { - BlockNode insertBlock = null; - while (exit != null) { - if (insertBlock != null && isPathExists(loopExit, exit)) { - // found cross - if (canInsertBreak(insertBlock)) { - insertEdge = new Edge(insertBlock, insertBlock.getSuccessors().get(0)); - confirm = true; - break; - } - return false; - } - insertBlock = exit; - List cs = exit.getCleanSuccessors(); - exit = cs.size() == 1 ? cs.get(0) : null; - } - } - if (!confirm) { - return false; - } - InsnNode breakInsn = new InsnNode(InsnType.BREAK, 0); - breakInsn.addAttr(AType.LOOP, loop); - EdgeInsnAttr.addEdgeInsn(insertEdge, breakInsn); - stack.addExit(exit); - // add label to 'break' if needed - addBreakLabel(exitEdge, exit, breakInsn); - return true; - } - - private void addBreakLabel(Edge exitEdge, BlockNode exit, InsnNode breakInsn) { - BlockNode outBlock = BlockUtils.getNextBlock(exitEdge.getTarget()); - if (outBlock == null) { - return; - } - List exitLoop = mth.getAllLoopsForBlock(outBlock); - if (!exitLoop.isEmpty()) { - return; - } - List inLoops = mth.getAllLoopsForBlock(exitEdge.getSource()); - if (inLoops.size() < 2) { - return; - } - // search for parent loop - LoopInfo parentLoop = null; - for (LoopInfo loop : inLoops) { - if (loop.getParentLoop() == null) { - parentLoop = loop; - break; - } - } - if (parentLoop == null) { - return; - } - if (parentLoop.getEnd() != exit && !parentLoop.getExitNodes().contains(exit)) { - LoopLabelAttr labelAttr = new LoopLabelAttr(parentLoop); - breakInsn.addAttr(labelAttr); - parentLoop.getStart().addAttr(labelAttr); - } - } - - private static void insertContinue(LoopInfo loop) { - BlockNode loopEnd = loop.getEnd(); - List predecessors = loopEnd.getPredecessors(); - if (predecessors.size() <= 1) { - return; - } - Set loopExitNodes = loop.getExitNodes(); - for (BlockNode pred : predecessors) { - if (canInsertContinue(pred, predecessors, loopEnd, loopExitNodes)) { - InsnNode cont = new InsnNode(InsnType.CONTINUE, 0); - pred.getInstructions().add(cont); - } - } - } - - private static boolean canInsertContinue(BlockNode pred, List predecessors, BlockNode loopEnd, - Set loopExitNodes) { - if (!pred.contains(AFlag.SYNTHETIC) - || BlockUtils.checkLastInsnType(pred, InsnType.CONTINUE)) { - return false; - } - List preds = pred.getPredecessors(); - if (preds.isEmpty()) { - return false; - } - BlockNode codePred = preds.get(0); - if (codePred.contains(AFlag.ADDED_TO_REGION)) { - return false; - } - if (loopEnd.isDominator(codePred) - || loopExitNodes.contains(codePred)) { - return false; - } - if (isDominatedOnBlocks(codePred, predecessors)) { - return false; - } - boolean gotoExit = false; - for (BlockNode exit : loopExitNodes) { - if (BlockUtils.isPathExists(codePred, exit)) { - gotoExit = true; - break; - } - } - return gotoExit; - } - - private static boolean isDominatedOnBlocks(BlockNode dom, List blocks) { - for (BlockNode node : blocks) { - if (!node.isDominator(dom)) { - return false; - } - } - return true; - } - - private BlockNode processMonitorEnter(IRegion curRegion, BlockNode block, InsnNode insn, RegionStack stack) { - SynchronizedRegion synchRegion = new SynchronizedRegion(curRegion, insn); - synchRegion.getSubBlocks().add(block); - curRegion.getSubBlocks().add(synchRegion); - - Set exits = new LinkedHashSet<>(); - Set cacheSet = new HashSet<>(); - traverseMonitorExits(synchRegion, insn.getArg(0), block, exits, cacheSet); - - for (InsnNode exitInsn : synchRegion.getExitInsns()) { - BlockNode insnBlock = BlockUtils.getBlockByInsn(mth, exitInsn); - if (insnBlock != null) { - insnBlock.add(AFlag.DONT_GENERATE); - } - // remove arg from MONITOR_EXIT to allow inline in MONITOR_ENTER - exitInsn.removeArg(0); - exitInsn.add(AFlag.DONT_GENERATE); - } - - BlockNode body = getNextBlock(block); - if (body == null) { - mth.addWarn("Unexpected end of synchronized block"); - return null; - } - BlockNode exit = null; - if (exits.size() == 1) { - exit = getNextBlock(exits.iterator().next()); - } else if (exits.size() > 1) { - cacheSet.clear(); - exit = traverseMonitorExitsCross(body, exits, cacheSet); - } - - stack.push(synchRegion); - if (exit != null) { - stack.addExit(exit); - } else { - for (BlockNode exitBlock : exits) { - // don't add exit blocks which leads to method end blocks ('return', 'throw', etc) - List list = BlockUtils.buildSimplePath(exitBlock); - if (list.isEmpty() || !BlockUtils.isExitBlock(mth, Utils.last(list))) { - stack.addExit(exitBlock); - // we can still try using this as an exit block to make sure it's visited. - exit = exitBlock; - } - } - } - synchRegion.getSubBlocks().add(makeRegion(body, stack)); - stack.pop(); - return exit; - } - - /** - * Traverse from monitor-enter thru successors and collect blocks contains monitor-exit - */ - private static void traverseMonitorExits(SynchronizedRegion region, InsnArg arg, BlockNode block, Set exits, - Set visited) { - visited.add(block); - for (InsnNode insn : block.getInstructions()) { - if (insn.getType() == InsnType.MONITOR_EXIT - && insn.getArgsCount() > 0 - && insn.getArg(0).equals(arg)) { - exits.add(block); - region.getExitInsns().add(insn); - return; - } - } - for (BlockNode node : block.getSuccessors()) { - if (!visited.contains(node)) { - traverseMonitorExits(region, arg, node, exits, visited); - } - } - } - - /** - * Traverse from monitor-enter thru successors and search for exit paths cross - */ - private static BlockNode traverseMonitorExitsCross(BlockNode block, Set exits, Set visited) { - visited.add(block); - for (BlockNode node : block.getCleanSuccessors()) { - boolean cross = true; - for (BlockNode exitBlock : exits) { - boolean p = isPathExists(exitBlock, node); - if (!p) { - cross = false; - break; - } - } - if (cross) { - return node; - } - if (!visited.contains(node)) { - BlockNode res = traverseMonitorExitsCross(node, exits, visited); - if (res != null) { - return res; - } - } - } - return null; - } - - private BlockNode processIf(IRegion currentRegion, BlockNode block, IfNode ifnode, RegionStack stack) { - if (block.contains(AFlag.ADDED_TO_REGION)) { - // block already included in other 'if' region - return ifnode.getThenBlock(); - } - - IfInfo currentIf = makeIfInfo(mth, block); - if (currentIf == null) { - return null; - } - IfInfo mergedIf = mergeNestedIfNodes(currentIf); - if (mergedIf != null) { - currentIf = mergedIf; - } else { - // invert simple condition (compiler often do it) - currentIf = IfInfo.invert(currentIf); - } - IfInfo modifiedIf = IfMakerHelper.restructureIf(mth, block, currentIf); - if (modifiedIf != null) { - currentIf = modifiedIf; - } else { - if (currentIf.getMergedBlocks().size() <= 1) { - return null; - } - currentIf = makeIfInfo(mth, block); - currentIf = IfMakerHelper.restructureIf(mth, block, currentIf); - if (currentIf == null) { - // all attempts failed - return null; - } - } - confirmMerge(currentIf); - - IfRegion ifRegion = new IfRegion(currentRegion); - ifRegion.updateCondition(currentIf); - currentRegion.getSubBlocks().add(ifRegion); - - BlockNode outBlock = currentIf.getOutBlock(); - stack.push(ifRegion); - stack.addExit(outBlock); - - ifRegion.setThenRegion(makeRegion(currentIf.getThenBlock(), stack)); - BlockNode elseBlock = currentIf.getElseBlock(); - if (elseBlock == null || stack.containsExit(elseBlock)) { - ifRegion.setElseRegion(null); - } else { - ifRegion.setElseRegion(makeRegion(elseBlock, stack)); - } - - // insert edge insns in new 'else' branch - // TODO: make more common algorithm - if (ifRegion.getElseRegion() == null && outBlock != null) { - List edgeInsnAttrs = outBlock.getAll(AType.EDGE_INSN); - if (!edgeInsnAttrs.isEmpty()) { - Region elseRegion = new Region(ifRegion); - for (EdgeInsnAttr edgeInsnAttr : edgeInsnAttrs) { - if (edgeInsnAttr.getEnd().equals(outBlock)) { - addEdgeInsn(currentIf, elseRegion, edgeInsnAttr); - } - } - ifRegion.setElseRegion(elseRegion); - } - } - - stack.pop(); - return outBlock; - } - - private void addEdgeInsn(IfInfo ifInfo, Region region, EdgeInsnAttr edgeInsnAttr) { - BlockNode start = edgeInsnAttr.getStart(); - boolean fromThisIf = false; - for (BlockNode ifBlock : ifInfo.getMergedBlocks()) { - if (ifBlock.getSuccessors().contains(start)) { - fromThisIf = true; - break; - } - } - if (!fromThisIf) { - return; - } - region.add(start); - } - - private BlockNode processSwitch(IRegion currentRegion, BlockNode block, SwitchInsn insn, RegionStack stack) { - // map case blocks to keys - int len = insn.getTargets().length; - Map> blocksMap = new LinkedHashMap<>(len); - BlockNode[] targetBlocksArr = insn.getTargetBlocks(); - for (int i = 0; i < len; i++) { - List keys = blocksMap.computeIfAbsent(targetBlocksArr[i], k -> new ArrayList<>(2)); - keys.add(insn.getKey(i)); - } - BlockNode defCase = insn.getDefTargetBlock(); - if (defCase != null) { - List keys = blocksMap.computeIfAbsent(defCase, k -> new ArrayList<>(1)); - keys.add(SwitchRegion.DEFAULT_CASE_KEY); - } - - SwitchRegion sw = new SwitchRegion(currentRegion, block); - insn.addAttr(new RegionRefAttr(sw)); - currentRegion.getSubBlocks().add(sw); - stack.push(sw); - - BlockNode out = calcSwitchOut(block, stack); - stack.addExit(out); - - processFallThroughCases(sw, out, stack, blocksMap); - removeEmptyCases(insn, sw, defCase); - - stack.pop(); - return out; - } - - private void processFallThroughCases(SwitchRegion sw, @Nullable BlockNode out, - RegionStack stack, Map> blocksMap) { - Map fallThroughCases = new LinkedHashMap<>(); - if (out != null) { - // detect fallthrough cases - BitSet caseBlocks = BlockUtils.blocksToBitSet(mth, blocksMap.keySet()); - caseBlocks.clear(out.getId()); - for (BlockNode successor : sw.getHeader().getCleanSuccessors()) { - BitSet df = successor.getDomFrontier(); - if (df.intersects(caseBlocks)) { - BlockNode fallThroughBlock = getOneIntersectionBlock(out, caseBlocks, df); - fallThroughCases.put(successor, fallThroughBlock); - } - } - // check fallthrough cases order - if (!fallThroughCases.isEmpty() && isBadCasesOrder(blocksMap, fallThroughCases)) { - Map> newBlocksMap = reOrderSwitchCases(blocksMap, fallThroughCases); - if (isBadCasesOrder(newBlocksMap, fallThroughCases)) { - mth.addWarnComment("Can't fix incorrect switch cases order, some code will duplicate"); - fallThroughCases.clear(); - } else { - blocksMap = newBlocksMap; - } - } - } - - for (Entry> entry : blocksMap.entrySet()) { - List keysList = entry.getValue(); - BlockNode caseBlock = entry.getKey(); - if (stack.containsExit(caseBlock)) { - sw.addCase(keysList, new Region(stack.peekRegion())); - } else { - BlockNode next = fallThroughCases.get(caseBlock); - stack.addExit(next); - Region caseRegion = makeRegion(caseBlock, stack); - stack.removeExit(next); - if (next != null) { - next.add(AFlag.FALL_THROUGH); - caseRegion.add(AFlag.FALL_THROUGH); - } - sw.addCase(keysList, caseRegion); - // 'break' instruction will be inserted in RegionMakerVisitor.PostRegionVisitor - } - } - } - - @Nullable - private BlockNode getOneIntersectionBlock(BlockNode out, BitSet caseBlocks, BitSet fallThroughSet) { - BitSet caseExits = BlockUtils.copyBlocksBitSet(mth, fallThroughSet); - caseExits.clear(out.getId()); - caseExits.and(caseBlocks); - return BlockUtils.bitSetToOneBlock(mth, caseExits); - } - - private @Nullable BlockNode calcSwitchOut(BlockNode block, RegionStack stack) { - // union of case blocks dominance frontier - // works if no fallthrough cases and no returns inside switch - BitSet outs = BlockUtils.newBlocksBitSet(mth); - for (BlockNode s : block.getCleanSuccessors()) { - if (s.contains(AFlag.LOOP_END)) { - // loop end dom frontier is loop start, ignore it - continue; - } - outs.or(s.getDomFrontier()); - } - outs.clear(block.getId()); - outs.clear(mth.getExitBlock().getId()); - if (outs.isEmpty()) { - // switch already contains method exit - // add everything, out block not needed - return mth.getExitBlock(); - } - - BlockNode out = null; - if (outs.cardinality() == 1) { - // single exit - out = BlockUtils.bitSetToOneBlock(mth, outs); - } else { - // several switch exits - // possible 'return', 'continue' or fallthrough in one of the cases - LoopInfo loop = mth.getLoopForBlock(block); - if (loop != null) { - outs.andNot(loop.getStart().getPostDoms()); - outs.andNot(loop.getEnd().getPostDoms()); - BlockNode loopEnd = loop.getEnd(); - if (outs.cardinality() == 2 && outs.get(loopEnd.getId())) { - // insert 'continue' for cases lead to loop end - // expect only 2 exits: loop end and switch out - List outList = BlockUtils.bitSetToBlocks(mth, outs); - outList.remove(loopEnd); - BlockNode possibleOut = Utils.getOne(outList); - if (possibleOut != null && insertContinueInSwitch(block, possibleOut, loopEnd)) { - outs.clear(loopEnd.getId()); - out = possibleOut; - } - } - } - if (out == null) { - BlockNode imPostDom = block.getIPostDom(); - if (outs.get(imPostDom.getId())) { - out = imPostDom; - } else { - outs.andNot(block.getPostDoms()); - out = BlockUtils.bitSetToOneBlock(mth, outs); - } - } - } - if (out != null && mth.isPreExitBlock(out)) { - // include 'return' or 'throw' in case blocks - out = mth.getExitBlock(); - } - BlockNode imPostDom = block.getIPostDom(); - if (out != imPostDom && !mth.isPreExitBlock(imPostDom)) { - // stop other paths at common exit - stack.addExit(imPostDom); - } - if (block.getCleanSuccessors().contains(imPostDom)) { - // add exit to stop on empty 'default' block - stack.addExit(imPostDom); - } - if (out == null) { - mth.addWarnComment("Failed to find 'out' block for switch in " + block + ". Please report as an issue."); - // fallback option; should work in most cases - out = block.getIPostDom(); - } - if (out != null && processedBlocks.get(out.getId())) { - // 'out' block already processed, prevent endless loop - throw new JadxRuntimeException("Failed to find switch 'out' block (already processed)"); - } - return out; - } - - /** - * Remove empty case blocks: - * 1. single 'default' case - * 2. filler cases if switch is 'packed' and 'default' case is empty - */ - private void removeEmptyCases(SwitchInsn insn, SwitchRegion sw, BlockNode defCase) { - boolean defaultCaseIsEmpty; - if (defCase == null) { - defaultCaseIsEmpty = true; - } else { - defaultCaseIsEmpty = sw.getCases().stream() - .anyMatch(c -> c.getKeys().contains(SwitchRegion.DEFAULT_CASE_KEY) - && RegionUtils.isEmpty(c.getContainer())); - } - if (defaultCaseIsEmpty) { - sw.getCases().removeIf(caseInfo -> { - if (RegionUtils.isEmpty(caseInfo.getContainer())) { - List keys = caseInfo.getKeys(); - if (keys.contains(SwitchRegion.DEFAULT_CASE_KEY)) { - return true; - } - if (insn.isPacked()) { - return true; - } - } - return false; - }); - } - } - - private boolean isBadCasesOrder(Map> blocksMap, Map fallThroughCases) { - BlockNode nextCaseBlock = null; - for (BlockNode caseBlock : blocksMap.keySet()) { - if (nextCaseBlock != null && !caseBlock.equals(nextCaseBlock)) { - return true; - } - nextCaseBlock = fallThroughCases.get(caseBlock); - } - return nextCaseBlock != null; - } - - private Map> reOrderSwitchCases(Map> blocksMap, - Map fallThroughCases) { - List list = new ArrayList<>(blocksMap.size()); - list.addAll(blocksMap.keySet()); - list.sort((a, b) -> { - BlockNode nextA = fallThroughCases.get(a); - if (nextA != null) { - if (b.equals(nextA)) { - return -1; - } - } else if (a.equals(fallThroughCases.get(b))) { - return 1; - } - return 0; - }); - - Map> newBlocksMap = new LinkedHashMap<>(blocksMap.size()); - for (BlockNode key : list) { - newBlocksMap.put(key, blocksMap.get(key)); - } - return newBlocksMap; - } - - private boolean insertContinueInSwitch(BlockNode switchBlock, BlockNode switchOut, BlockNode loopEnd) { - boolean inserted = false; - for (BlockNode caseBlock : switchBlock.getCleanSuccessors()) { - if (caseBlock.getDomFrontier().get(loopEnd.getId()) && caseBlock != switchOut) { - // search predecessor of loop end on path from this successor - Set list = new HashSet<>(BlockUtils.collectBlocksDominatedBy(mth, caseBlock, caseBlock)); - if (list.contains(switchOut) || switchOut.getPredecessors().stream().anyMatch(list::contains)) { - // 'continue' not needed - } else { - for (BlockNode p : loopEnd.getPredecessors()) { - if (list.contains(p)) { - if (p.isSynthetic()) { - p.getInstructions().add(new InsnNode(InsnType.CONTINUE, 0)); - inserted = true; - } - break; - } - } - } - } - } - return inserted; - } - - public IRegion processTryCatchBlocks(MethodNode mth) { - List tcs = mth.getAll(AType.TRY_BLOCKS_LIST); - for (TryCatchBlockAttr tc : tcs) { - List blocks = new ArrayList<>(tc.getHandlersCount()); - Set splitters = new HashSet<>(); - for (ExceptionHandler handler : tc.getHandlers()) { - BlockNode handlerBlock = handler.getHandlerBlock(); - if (handlerBlock != null) { - blocks.add(handlerBlock); - splitters.add(BlockUtils.getTopSplitterForHandler(handlerBlock)); - } else { - mth.addDebugComment("No exception handler block: " + handler); - } - } - Set exits = new HashSet<>(); - for (BlockNode splitter : splitters) { - for (BlockNode handler : blocks) { - if (handler.contains(AFlag.REMOVE)) { - continue; - } - List s = splitter.getSuccessors(); - if (s.isEmpty()) { - mth.addDebugComment("No successors for splitter: " + splitter); - continue; - } - BlockNode ss = s.get(0); - BlockNode cross = BlockUtils.getPathCross(mth, ss, handler); - if (cross != null && cross != ss && cross != handler) { - exits.add(cross); - } - } - } - for (ExceptionHandler handler : tc.getHandlers()) { - processExcHandler(mth, handler, exits); - } - } - return processHandlersOutBlocks(mth, tcs); - } - - /** - * Search handlers successor blocks not included in any region. - */ - protected IRegion processHandlersOutBlocks(MethodNode mth, List tcs) { - Set allRegionBlocks = new HashSet<>(); - RegionUtils.getAllRegionBlocks(mth.getRegion(), allRegionBlocks); - - Set succBlocks = new HashSet<>(); - for (TryCatchBlockAttr tc : tcs) { - for (ExceptionHandler handler : tc.getHandlers()) { - IContainer region = handler.getHandlerRegion(); - if (region != null) { - IBlock lastBlock = RegionUtils.getLastBlock(region); - if (lastBlock instanceof BlockNode) { - succBlocks.addAll(((BlockNode) lastBlock).getSuccessors()); - } - RegionUtils.getAllRegionBlocks(region, allRegionBlocks); - } - } - } - succBlocks.removeAll(allRegionBlocks); - if (succBlocks.isEmpty()) { - return null; - } - Region excOutRegion = new Region(mth.getRegion()); - for (IBlock block : succBlocks) { - if (block instanceof BlockNode) { - excOutRegion.add(makeRegion((BlockNode) block, new RegionStack(mth))); - } - } - return excOutRegion; - } - - private void processExcHandler(MethodNode mth, ExceptionHandler handler, Set exits) { - BlockNode start = handler.getHandlerBlock(); - if (start == null) { - return; - } - RegionStack stack = new RegionStack(this.mth); - BlockNode dom; - if (handler.isFinally()) { - dom = BlockUtils.getTopSplitterForHandler(start); - } else { - dom = start; - stack.addExits(exits); - } - if (dom.contains(AFlag.REMOVE)) { - return; - } - BitSet domFrontier = dom.getDomFrontier(); - List handlerExits = BlockUtils.bitSetToBlocks(this.mth, domFrontier); - boolean inLoop = this.mth.getLoopForBlock(start) != null; - for (BlockNode exit : handlerExits) { - if ((!inLoop || BlockUtils.isPathExists(start, exit)) - && RegionUtils.isRegionContainsBlock(this.mth.getRegion(), exit)) { - stack.addExit(exit); - } - } - handler.setHandlerRegion(makeRegion(start, stack)); - - ExcHandlerAttr excHandlerAttr = start.get(AType.EXC_HANDLER); - if (excHandlerAttr == null) { - mth.addWarn("Missing exception handler attribute for start block: " + start); - } else { - handler.getHandlerRegion().addAttr(excHandlerAttr); - } - } - - static boolean isEqualPaths(BlockNode b1, BlockNode b2) { - if (b1 == b2) { - return true; - } - if (b1 == null || b2 == null) { - return false; - } - return isEqualReturnBlocks(b1, b2) || isEmptySyntheticPath(b1, b2); - } - - private static boolean isEmptySyntheticPath(BlockNode b1, BlockNode b2) { - BlockNode n1 = followEmptyPath(b1); - BlockNode n2 = followEmptyPath(b2); - return n1 == n2 || isEqualReturnBlocks(n1, n2); - } - - public static boolean isEqualReturnBlocks(BlockNode b1, BlockNode b2) { - if (!b1.isReturnBlock() || !b2.isReturnBlock()) { - return false; - } - List b1Insns = b1.getInstructions(); - List b2Insns = b2.getInstructions(); - if (b1Insns.size() != 1 || b2Insns.size() != 1) { - return false; - } - InsnNode i1 = b1Insns.get(0); - InsnNode i2 = b2Insns.get(0); - if (i1.getArgsCount() != i2.getArgsCount()) { - return false; - } - if (i1.getArgsCount() == 0) { - return true; - } - InsnArg firstArg = i1.getArg(0); - InsnArg secondArg = i2.getArg(0); - if (firstArg.isSameConst(secondArg)) { - return true; - } - if (i1.getSourceLine() != i2.getSourceLine()) { - return false; - } - return firstArg.equals(secondArg); - } -} diff --git a/jadx-core/src/main/java/jadx/core/dex/visitors/regions/RegionMakerVisitor.java b/jadx-core/src/main/java/jadx/core/dex/visitors/regions/RegionMakerVisitor.java index a26c28a90..00d667d72 100644 --- a/jadx-core/src/main/java/jadx/core/dex/visitors/regions/RegionMakerVisitor.java +++ b/jadx-core/src/main/java/jadx/core/dex/visitors/regions/RegionMakerVisitor.java @@ -1,43 +1,20 @@ package jadx.core.dex.visitors.regions; -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashSet; -import java.util.List; -import java.util.Set; - -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - import jadx.core.dex.attributes.AFlag; -import jadx.core.dex.attributes.AType; -import jadx.core.dex.attributes.nodes.EdgeInsnAttr; -import jadx.core.dex.instructions.InsnType; -import jadx.core.dex.nodes.BlockNode; -import jadx.core.dex.nodes.IBlock; -import jadx.core.dex.nodes.IContainer; -import jadx.core.dex.nodes.IRegion; -import jadx.core.dex.nodes.InsnContainer; -import jadx.core.dex.nodes.InsnNode; import jadx.core.dex.nodes.MethodNode; -import jadx.core.dex.regions.Region; -import jadx.core.dex.regions.SwitchRegion; -import jadx.core.dex.regions.SynchronizedRegion; -import jadx.core.dex.regions.loops.LoopRegion; import jadx.core.dex.visitors.AbstractVisitor; +import jadx.core.dex.visitors.JadxVisitor; +import jadx.core.dex.visitors.regions.maker.ExcHandlersRegionMaker; +import jadx.core.dex.visitors.regions.maker.RegionMaker; +import jadx.core.dex.visitors.regions.maker.SynchronizedRegionMaker; import jadx.core.dex.visitors.shrink.CodeShrinkVisitor; -import jadx.core.utils.InsnRemover; -import jadx.core.utils.RegionUtils; -import jadx.core.utils.Utils; import jadx.core.utils.exceptions.JadxException; -/** - * Pack blocks into regions for code generation - */ +@JadxVisitor( + name = "RegionMakerVisitor", + desc = "Pack blocks into regions for code generation" +) public class RegionMakerVisitor extends AbstractVisitor { - private static final Logger LOG = LoggerFactory.getLogger(RegionMakerVisitor.class); - - private static final IRegionVisitor POST_REGION_VISITOR = new PostRegionVisitor(); @Override public void visit(MethodNode mth) throws JadxException { @@ -45,33 +22,16 @@ public class RegionMakerVisitor extends AbstractVisitor { return; } RegionMaker rm = new RegionMaker(mth); - RegionStack state = new RegionStack(mth); - - // fill region structure - BlockNode startBlock = Utils.first(mth.getEnterBlock().getCleanSuccessors()); - mth.setRegion(rm.makeRegion(startBlock, state)); - + mth.setRegion(rm.makeMthRegion()); if (!mth.isNoExceptionHandlers()) { - IRegion expOutBlock = rm.processTryCatchBlocks(mth); - if (expOutBlock != null) { - mth.getRegion().add(expOutBlock); - } + new ExcHandlersRegionMaker(mth, rm).process(); } - postProcessRegions(mth); - } - - private static void postProcessRegions(MethodNode mth) { processForceInlineInsns(mth); - - // make try-catch regions ProcessTryCatchRegions.process(mth); - - DepthRegionTraversal.traverse(mth, POST_REGION_VISITOR); - + PostProcessRegions.process(mth); CleanRegions.process(mth); - if (mth.getAccessFlags().isSynchronized()) { - removeSynchronized(mth); + SynchronizedRegionMaker.removeSynchronized(mth); } } @@ -84,120 +44,8 @@ public class RegionMakerVisitor extends AbstractVisitor { } } - private static final class PostRegionVisitor extends AbstractRegionVisitor { - @Override - public void leaveRegion(MethodNode mth, IRegion region) { - if (region instanceof LoopRegion) { - // merge conditions in loops - LoopRegion loop = (LoopRegion) region; - loop.mergePreCondition(); - } else if (region instanceof SwitchRegion) { - // insert 'break' in switch cases (run after try/catch insertion) - processSwitch(mth, (SwitchRegion) region); - } else if (region instanceof Region) { - insertEdgeInsn((Region) region); - } - } - - /** - * Insert insn block from edge insn attribute. - */ - private static void insertEdgeInsn(Region region) { - List subBlocks = region.getSubBlocks(); - if (subBlocks.isEmpty()) { - return; - } - IContainer last = subBlocks.get(subBlocks.size() - 1); - List edgeInsnAttrs = last.getAll(AType.EDGE_INSN); - if (edgeInsnAttrs.isEmpty()) { - return; - } - EdgeInsnAttr insnAttr = edgeInsnAttrs.get(0); - if (!insnAttr.getStart().equals(last)) { - return; - } - if (last instanceof BlockNode) { - BlockNode block = (BlockNode) last; - if (block.getInstructions().isEmpty()) { - block.getInstructions().add(insnAttr.getInsn()); - return; - } - } - List insns = Collections.singletonList(insnAttr.getInsn()); - region.add(new InsnContainer(insns)); - } - - private static void processSwitch(MethodNode mth, SwitchRegion sw) { - for (IContainer c : sw.getBranches()) { - if (c instanceof Region) { - Set blocks = new HashSet<>(); - RegionUtils.getAllRegionBlocks(c, blocks); - if (blocks.isEmpty()) { - addBreakToContainer((Region) c); - } else { - for (IBlock block : blocks) { - if (block instanceof BlockNode) { - addBreakForBlock(mth, c, blocks, (BlockNode) block); - } - } - } - } - } - } - - private static void addBreakToContainer(Region c) { - if (RegionUtils.hasExitEdge(c)) { - return; - } - List insns = new ArrayList<>(1); - insns.add(new InsnNode(InsnType.BREAK, 0)); - c.add(new InsnContainer(insns)); - } - - private static void addBreakForBlock(MethodNode mth, IContainer c, Set blocks, BlockNode bn) { - for (BlockNode s : bn.getCleanSuccessors()) { - if (!blocks.contains(s) - && !bn.contains(AFlag.ADDED_TO_REGION) - && !s.contains(AFlag.FALL_THROUGH)) { - addBreak(mth, c, bn); - return; - } - } - } - - private static void addBreak(MethodNode mth, IContainer c, BlockNode bn) { - IContainer blockContainer = RegionUtils.getBlockContainer(c, bn); - if (blockContainer instanceof Region) { - addBreakToContainer((Region) blockContainer); - } else if (c instanceof Region) { - addBreakToContainer((Region) c); - } else { - LOG.warn("Can't insert break, container: {}, block: {}, mth: {}", blockContainer, bn, mth); - } - } - } - - private static void removeSynchronized(MethodNode mth) { - Region startRegion = mth.getRegion(); - List subBlocks = startRegion.getSubBlocks(); - if (!subBlocks.isEmpty() && subBlocks.get(0) instanceof SynchronizedRegion) { - SynchronizedRegion synchRegion = (SynchronizedRegion) subBlocks.get(0); - InsnNode synchInsn = synchRegion.getEnterInsn(); - if (!synchInsn.getArg(0).isThis()) { - LOG.warn("In synchronized method {}, top region not synchronized by 'this' {}", mth, synchInsn); - return; - } - // replace synchronized block with inner region - startRegion.getSubBlocks().set(0, synchRegion.getRegion()); - // remove 'monitor-enter' instruction - InsnRemover.remove(mth, synchInsn); - // remove 'monitor-exit' instruction - for (InsnNode exit : synchRegion.getExitInsns()) { - InsnRemover.remove(mth, exit); - } - // run region cleaner again - CleanRegions.process(mth); - // assume that CodeShrinker will be run after this - } + @Override + public String getName() { + return "RegionMakerVisitor"; } } diff --git a/jadx-core/src/main/java/jadx/core/dex/visitors/regions/maker/ExcHandlersRegionMaker.java b/jadx-core/src/main/java/jadx/core/dex/visitors/regions/maker/ExcHandlersRegionMaker.java new file mode 100644 index 000000000..8d74dc0a5 --- /dev/null +++ b/jadx-core/src/main/java/jadx/core/dex/visitors/regions/maker/ExcHandlersRegionMaker.java @@ -0,0 +1,153 @@ +package jadx.core.dex.visitors.regions.maker; + +import java.util.ArrayList; +import java.util.BitSet; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +import org.jetbrains.annotations.Nullable; + +import jadx.core.dex.attributes.AFlag; +import jadx.core.dex.attributes.AType; +import jadx.core.dex.nodes.BlockNode; +import jadx.core.dex.nodes.IBlock; +import jadx.core.dex.nodes.IContainer; +import jadx.core.dex.nodes.IRegion; +import jadx.core.dex.nodes.MethodNode; +import jadx.core.dex.regions.Region; +import jadx.core.dex.trycatch.ExcHandlerAttr; +import jadx.core.dex.trycatch.ExceptionHandler; +import jadx.core.dex.trycatch.TryCatchBlockAttr; +import jadx.core.utils.BlockUtils; +import jadx.core.utils.RegionUtils; + +public class ExcHandlersRegionMaker { + private final MethodNode mth; + private final RegionMaker regionMaker; + + public ExcHandlersRegionMaker(MethodNode mth, RegionMaker regionMaker) { + this.mth = mth; + this.regionMaker = regionMaker; + } + + public void process() { + if (mth.isNoExceptionHandlers()) { + return; + } + IRegion excOutBlock = collectHandlerRegions(); + if (excOutBlock != null) { + mth.getRegion().add(excOutBlock); + } + } + + private @Nullable IRegion collectHandlerRegions() { + List tcs = mth.getAll(AType.TRY_BLOCKS_LIST); + for (TryCatchBlockAttr tc : tcs) { + List blocks = new ArrayList<>(tc.getHandlersCount()); + Set splitters = new HashSet<>(); + for (ExceptionHandler handler : tc.getHandlers()) { + BlockNode handlerBlock = handler.getHandlerBlock(); + if (handlerBlock != null) { + blocks.add(handlerBlock); + splitters.add(BlockUtils.getTopSplitterForHandler(handlerBlock)); + } else { + mth.addDebugComment("No exception handler block: " + handler); + } + } + Set exits = new HashSet<>(); + for (BlockNode splitter : splitters) { + for (BlockNode handler : blocks) { + if (handler.contains(AFlag.REMOVE)) { + continue; + } + List s = splitter.getSuccessors(); + if (s.isEmpty()) { + mth.addDebugComment("No successors for splitter: " + splitter); + continue; + } + BlockNode ss = s.get(0); + BlockNode cross = BlockUtils.getPathCross(mth, ss, handler); + if (cross != null && cross != ss && cross != handler) { + exits.add(cross); + } + } + } + for (ExceptionHandler handler : tc.getHandlers()) { + processExcHandler(handler, exits); + } + } + return processHandlersOutBlocks(tcs); + } + + /** + * Search handlers successor blocks aren't included in any region. + */ + private @Nullable IRegion processHandlersOutBlocks(List tcs) { + Set allRegionBlocks = new HashSet<>(); + RegionUtils.getAllRegionBlocks(mth.getRegion(), allRegionBlocks); + + Set successorBlocks = new HashSet<>(); + for (TryCatchBlockAttr tc : tcs) { + for (ExceptionHandler handler : tc.getHandlers()) { + IContainer region = handler.getHandlerRegion(); + if (region != null) { + IBlock lastBlock = RegionUtils.getLastBlock(region); + if (lastBlock instanceof BlockNode) { + successorBlocks.addAll(((BlockNode) lastBlock).getSuccessors()); + } + RegionUtils.getAllRegionBlocks(region, allRegionBlocks); + } + } + } + successorBlocks.removeAll(allRegionBlocks); + if (successorBlocks.isEmpty()) { + return null; + } + RegionStack stack = regionMaker.getStack(); + Region excOutRegion = new Region(mth.getRegion()); + for (IBlock block : successorBlocks) { + if (block instanceof BlockNode) { + stack.clear(); + stack.push(excOutRegion); + excOutRegion.add(regionMaker.makeRegion((BlockNode) block)); + } + } + return excOutRegion; + } + + private void processExcHandler(ExceptionHandler handler, Set exits) { + BlockNode start = handler.getHandlerBlock(); + if (start == null) { + return; + } + RegionStack stack = regionMaker.getStack().clear(); + BlockNode dom; + if (handler.isFinally()) { + dom = BlockUtils.getTopSplitterForHandler(start); + } else { + dom = start; + stack.addExits(exits); + } + if (dom.contains(AFlag.REMOVE)) { + return; + } + BitSet domFrontier = dom.getDomFrontier(); + List handlerExits = BlockUtils.bitSetToBlocks(mth, domFrontier); + boolean inLoop = mth.getLoopForBlock(start) != null; + for (BlockNode exit : handlerExits) { + if ((!inLoop || BlockUtils.isPathExists(start, exit)) + && RegionUtils.isRegionContainsBlock(mth.getRegion(), exit)) { + stack.addExit(exit); + } + } + handler.setHandlerRegion(regionMaker.makeRegion(start)); + + ExcHandlerAttr excHandlerAttr = start.get(AType.EXC_HANDLER); + if (excHandlerAttr == null) { + mth.addWarn("Missing exception handler attribute for start block: " + start); + } else { + handler.getHandlerRegion().addAttr(excHandlerAttr); + } + } +} diff --git a/jadx-core/src/main/java/jadx/core/dex/visitors/regions/IfMakerHelper.java b/jadx-core/src/main/java/jadx/core/dex/visitors/regions/maker/IfRegionMaker.java similarity index 77% rename from jadx-core/src/main/java/jadx/core/dex/visitors/regions/IfMakerHelper.java rename to jadx-core/src/main/java/jadx/core/dex/visitors/regions/maker/IfRegionMaker.java index d0b391263..7711d870f 100644 --- a/jadx-core/src/main/java/jadx/core/dex/visitors/regions/IfMakerHelper.java +++ b/jadx-core/src/main/java/jadx/core/dex/visitors/regions/maker/IfRegionMaker.java @@ -1,10 +1,11 @@ -package jadx.core.dex.visitors.regions; +package jadx.core.dex.visitors.regions.maker; import java.util.ArrayList; import java.util.List; import java.util.Objects; import java.util.Set; +import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.Nullable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -12,28 +13,133 @@ import org.slf4j.LoggerFactory; import jadx.core.Consts; import jadx.core.dex.attributes.AFlag; import jadx.core.dex.attributes.AType; +import jadx.core.dex.attributes.nodes.EdgeInsnAttr; import jadx.core.dex.attributes.nodes.LoopInfo; import jadx.core.dex.instructions.IfNode; import jadx.core.dex.instructions.InsnType; import jadx.core.dex.instructions.args.InsnArg; import jadx.core.dex.instructions.args.RegisterArg; import jadx.core.dex.nodes.BlockNode; +import jadx.core.dex.nodes.IRegion; import jadx.core.dex.nodes.InsnNode; import jadx.core.dex.nodes.MethodNode; +import jadx.core.dex.regions.Region; import jadx.core.dex.regions.conditions.IfCondition; -import jadx.core.dex.regions.conditions.IfCondition.Mode; import jadx.core.dex.regions.conditions.IfInfo; +import jadx.core.dex.regions.conditions.IfRegion; +import jadx.core.dex.regions.loops.LoopRegion; import jadx.core.utils.BlockUtils; import jadx.core.utils.exceptions.JadxRuntimeException; -import static jadx.core.dex.visitors.regions.RegionMaker.isEqualPaths; -import static jadx.core.dex.visitors.regions.RegionMaker.isEqualReturnBlocks; +import static jadx.core.utils.BlockUtils.isEqualPaths; +import static jadx.core.utils.BlockUtils.isEqualReturnBlocks; import static jadx.core.utils.BlockUtils.isPathExists; -public class IfMakerHelper { - private static final Logger LOG = LoggerFactory.getLogger(IfMakerHelper.class); +final class IfRegionMaker { + private static final Logger LOG = LoggerFactory.getLogger(IfRegionMaker.class); + private final MethodNode mth; + private final RegionMaker regionMaker; - private IfMakerHelper() { + IfRegionMaker(MethodNode mth, RegionMaker regionMaker) { + this.mth = mth; + this.regionMaker = regionMaker; + } + + BlockNode process(IRegion currentRegion, BlockNode block, IfNode ifnode, RegionStack stack) { + if (block.contains(AFlag.ADDED_TO_REGION)) { + // block already included in other 'if' region + return ifnode.getThenBlock(); + } + + IfInfo currentIf = makeIfInfo(mth, block); + if (currentIf == null) { + return null; + } + IfInfo mergedIf = mergeNestedIfNodes(currentIf); + if (mergedIf != null) { + currentIf = mergedIf; + } else { + // invert simple condition (compiler often do it) + currentIf = IfInfo.invert(currentIf); + } + IfInfo modifiedIf = restructureIf(mth, block, currentIf); + if (modifiedIf != null) { + currentIf = modifiedIf; + } else { + if (currentIf.getMergedBlocks().size() <= 1) { + return null; + } + currentIf = makeIfInfo(mth, block); + currentIf = restructureIf(mth, block, currentIf); + if (currentIf == null) { + // all attempts failed + return null; + } + } + confirmMerge(currentIf); + + IfRegion ifRegion = new IfRegion(currentRegion); + ifRegion.updateCondition(currentIf); + currentRegion.getSubBlocks().add(ifRegion); + + BlockNode outBlock = currentIf.getOutBlock(); + stack.push(ifRegion); + stack.addExit(outBlock); + + BlockNode thenBlock = currentIf.getThenBlock(); + if (thenBlock == null) { + // empty then block, not normal, but maybe correct + ifRegion.setThenRegion(new Region(ifRegion)); + } else { + ifRegion.setThenRegion(regionMaker.makeRegion(thenBlock)); + } + BlockNode elseBlock = currentIf.getElseBlock(); + if (elseBlock == null || stack.containsExit(elseBlock)) { + ifRegion.setElseRegion(null); + } else { + ifRegion.setElseRegion(regionMaker.makeRegion(elseBlock)); + } + + // insert edge insns in new 'else' branch + // TODO: make more common algorithm + if (ifRegion.getElseRegion() == null && outBlock != null) { + List edgeInsnAttrs = outBlock.getAll(AType.EDGE_INSN); + if (!edgeInsnAttrs.isEmpty()) { + Region elseRegion = new Region(ifRegion); + for (EdgeInsnAttr edgeInsnAttr : edgeInsnAttrs) { + if (edgeInsnAttr.getEnd().equals(outBlock)) { + addEdgeInsn(currentIf, elseRegion, edgeInsnAttr); + } + } + ifRegion.setElseRegion(elseRegion); + } + } + + stack.pop(); + return outBlock; + } + + @NotNull + IfInfo buildIfInfo(LoopRegion loopRegion) { + IfInfo condInfo = makeIfInfo(mth, loopRegion.getHeader()); + condInfo = searchNestedIf(condInfo); + confirmMerge(condInfo); + return condInfo; + } + + private void addEdgeInsn(IfInfo ifInfo, Region region, EdgeInsnAttr edgeInsnAttr) { + BlockNode start = edgeInsnAttr.getStart(); + boolean fromThisIf = false; + for (BlockNode ifBlock : ifInfo.getMergedBlocks()) { + if (ifBlock.getSuccessors().contains(start)) { + fromThisIf = true; + break; + } + } + if (!fromThisIf) { + return; + } + region.add(start); } @Nullable @@ -262,7 +368,7 @@ public class IfMakerHelper { return from.getCleanSuccessors().size() == 1 && from.getCleanSuccessors().contains(to); } - private static IfInfo mergeIfInfo(IfInfo first, IfInfo second, boolean followThenBranch) { + static IfInfo mergeIfInfo(IfInfo first, IfInfo second, boolean followThenBranch) { MethodNode mth = first.getMth(); Set skipBlocks = first.getSkipBlocks(); BlockNode thenBlock; @@ -274,7 +380,7 @@ public class IfMakerHelper { thenBlock = getBranchBlock(first.getThenBlock(), second.getThenBlock(), skipBlocks, mth); elseBlock = second.getElseBlock(); } - Mode mergeOperation = followThenBranch ? Mode.AND : Mode.OR; + IfCondition.Mode mergeOperation = followThenBranch ? IfCondition.Mode.AND : IfCondition.Mode.OR; IfCondition condition = IfCondition.merge(mergeOperation, first.getCondition(), second.getCondition()); IfInfo result = new IfInfo(mth, condition, thenBlock, elseBlock); result.merge(first, second); diff --git a/jadx-core/src/main/java/jadx/core/dex/visitors/regions/maker/LoopRegionMaker.java b/jadx-core/src/main/java/jadx/core/dex/visitors/regions/maker/LoopRegionMaker.java new file mode 100644 index 000000000..5ce38ac86 --- /dev/null +++ b/jadx-core/src/main/java/jadx/core/dex/visitors/regions/maker/LoopRegionMaker.java @@ -0,0 +1,464 @@ +package jadx.core.dex.visitors.regions.maker; + +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; + +import jadx.core.dex.attributes.AFlag; +import jadx.core.dex.attributes.AType; +import jadx.core.dex.attributes.nodes.EdgeInsnAttr; +import jadx.core.dex.attributes.nodes.LoopInfo; +import jadx.core.dex.attributes.nodes.LoopLabelAttr; +import jadx.core.dex.instructions.InsnType; +import jadx.core.dex.nodes.BlockNode; +import jadx.core.dex.nodes.Edge; +import jadx.core.dex.nodes.IRegion; +import jadx.core.dex.nodes.InsnNode; +import jadx.core.dex.nodes.MethodNode; +import jadx.core.dex.regions.Region; +import jadx.core.dex.regions.conditions.IfInfo; +import jadx.core.dex.regions.loops.LoopRegion; +import jadx.core.dex.trycatch.ExceptionHandler; +import jadx.core.utils.BlockUtils; +import jadx.core.utils.ListUtils; +import jadx.core.utils.RegionUtils; +import jadx.core.utils.exceptions.JadxRuntimeException; + +import static jadx.core.utils.BlockUtils.getNextBlock; +import static jadx.core.utils.BlockUtils.isPathExists; + +final class LoopRegionMaker { + private final MethodNode mth; + private final RegionMaker regionMaker; + private final IfRegionMaker ifMaker; + + LoopRegionMaker(MethodNode mth, RegionMaker regionMaker, IfRegionMaker ifMaker) { + this.mth = mth; + this.regionMaker = regionMaker; + this.ifMaker = ifMaker; + } + + BlockNode process(IRegion curRegion, LoopInfo loop, RegionStack stack) { + BlockNode loopStart = loop.getStart(); + Set exitBlocksSet = loop.getExitNodes(); + + // set exit blocks scan order priority + // this can help if loop has several exits (after using 'break' or 'return' in loop) + List exitBlocks = new ArrayList<>(exitBlocksSet.size()); + BlockNode nextStart = getNextBlock(loopStart); + if (nextStart != null && exitBlocksSet.remove(nextStart)) { + exitBlocks.add(nextStart); + } + if (exitBlocksSet.remove(loopStart)) { + exitBlocks.add(loopStart); + } + if (exitBlocksSet.remove(loop.getEnd())) { + exitBlocks.add(loop.getEnd()); + } + exitBlocks.addAll(exitBlocksSet); + + LoopRegion loopRegion = makeLoopRegion(curRegion, loop, exitBlocks); + if (loopRegion == null) { + BlockNode exit = makeEndlessLoop(curRegion, stack, loop, loopStart); + insertContinue(loop); + return exit; + } + curRegion.getSubBlocks().add(loopRegion); + IRegion outerRegion = stack.peekRegion(); + stack.push(loopRegion); + + IfInfo condInfo = ifMaker.buildIfInfo(loopRegion); + if (!loop.getLoopBlocks().contains(condInfo.getThenBlock())) { + // invert loop condition if 'then' points to exit + condInfo = IfInfo.invert(condInfo); + } + loopRegion.updateCondition(condInfo); + // prevent if's merge with loop condition + condInfo.getMergedBlocks().forEach(b -> b.add(AFlag.ADDED_TO_REGION)); + exitBlocks.removeAll(condInfo.getMergedBlocks()); + + if (!exitBlocks.isEmpty()) { + BlockNode loopExit = condInfo.getElseBlock(); + if (loopExit != null) { + // add 'break' instruction before path cross between main loop exit and sub-exit + for (Edge exitEdge : loop.getExitEdges()) { + if (exitBlocks.contains(exitEdge.getSource())) { + insertLoopBreak(stack, loop, loopExit, exitEdge); + } + } + } + } + + BlockNode out; + if (loopRegion.isConditionAtEnd()) { + BlockNode thenBlock = condInfo.getThenBlock(); + out = thenBlock == loop.getEnd() || thenBlock == loopStart ? condInfo.getElseBlock() : thenBlock; + out = BlockUtils.followEmptyPath(out); + loopStart.remove(AType.LOOP); + loop.getEnd().add(AFlag.ADDED_TO_REGION); + stack.addExit(loop.getEnd()); + regionMaker.clearBlockProcessedState(loopStart); + Region body = regionMaker.makeRegion(loopStart); + loopRegion.setBody(body); + loopStart.addAttr(AType.LOOP, loop); + loop.getEnd().remove(AFlag.ADDED_TO_REGION); + } else { + out = condInfo.getElseBlock(); + if (outerRegion != null + && out != null + && out.contains(AFlag.LOOP_START) + && !out.getAll(AType.LOOP).contains(loop) + && RegionUtils.isRegionContainsBlock(outerRegion, out)) { + // exit to already processed outer loop + out = null; + } + stack.addExit(out); + BlockNode loopBody = condInfo.getThenBlock(); + Region body; + if (Objects.equals(loopBody, loopStart)) { + // empty loop body + body = new Region(loopRegion); + } else { + body = regionMaker.makeRegion(loopBody); + } + // add blocks from loop start to first condition block + BlockNode conditionBlock = condInfo.getFirstIfBlock(); + if (loopStart != conditionBlock) { + Set blocks = BlockUtils.getAllPathsBlocks(loopStart, conditionBlock); + blocks.remove(conditionBlock); + for (BlockNode block : blocks) { + if (block.getInstructions().isEmpty() + && !block.contains(AFlag.ADDED_TO_REGION) + && !RegionUtils.isRegionContainsBlock(body, block)) { + body.add(block); + } + } + } + loopRegion.setBody(body); + } + stack.pop(); + insertContinue(loop); + return out; + } + + /** + * Select loop exit and construct LoopRegion + */ + private LoopRegion makeLoopRegion(IRegion curRegion, LoopInfo loop, List exitBlocks) { + for (BlockNode block : exitBlocks) { + if (block.contains(AType.EXC_HANDLER)) { + continue; + } + InsnNode lastInsn = BlockUtils.getLastInsn(block); + if (lastInsn == null || lastInsn.getType() != InsnType.IF) { + continue; + } + List loops = block.getAll(AType.LOOP); + if (!loops.isEmpty() && loops.get(0) != loop) { + // skip nested loop condition + continue; + } + boolean exitAtLoopEnd = isExitAtLoopEnd(block, loop); + LoopRegion loopRegion = new LoopRegion(curRegion, loop, block, exitAtLoopEnd); + boolean found; + if (block == loop.getStart() || exitAtLoopEnd + || BlockUtils.isEmptySimplePath(loop.getStart(), block)) { + found = true; + } else if (block.getPredecessors().contains(loop.getStart())) { + loopRegion.setPreCondition(loop.getStart()); + // if we can't merge pre-condition this is not correct header + found = loopRegion.checkPreCondition(); + } else { + found = false; + } + if (found) { + List list = mth.getAllLoopsForBlock(block); + if (list.size() >= 2) { + // bad condition if successors going out of all loops + boolean allOuter = true; + for (BlockNode outerBlock : block.getCleanSuccessors()) { + List outLoopList = mth.getAllLoopsForBlock(outerBlock); + outLoopList.remove(loop); + if (!outLoopList.isEmpty()) { + // goes to outer loop + allOuter = false; + break; + } + } + if (allOuter) { + found = false; + } + } + } + if (found && !checkLoopExits(loop, block)) { + found = false; + } + if (found) { + return loopRegion; + } + } + // no exit found => endless loop + return null; + } + + private static boolean isExitAtLoopEnd(BlockNode exit, LoopInfo loop) { + BlockNode loopEnd = loop.getEnd(); + if (exit == loopEnd) { + return true; + } + BlockNode loopStart = loop.getStart(); + if (loopStart.getInstructions().isEmpty() && ListUtils.isSingleElement(loopStart.getSuccessors(), exit)) { + return false; + } + return loopEnd.getInstructions().isEmpty() && ListUtils.isSingleElement(loopEnd.getPredecessors(), exit); + } + + private boolean checkLoopExits(LoopInfo loop, BlockNode mainExitBlock) { + List exitEdges = loop.getExitEdges(); + if (exitEdges.size() < 2) { + return true; + } + Optional mainEdgeOpt = exitEdges.stream().filter(edge -> edge.getSource() == mainExitBlock).findFirst(); + if (mainEdgeOpt.isEmpty()) { + throw new JadxRuntimeException("Not found exit edge by exit block: " + mainExitBlock); + } + Edge mainExitEdge = mainEdgeOpt.get(); + BlockNode mainOutBlock = mainExitEdge.getTarget(); + for (Edge exitEdge : exitEdges) { + if (exitEdge != mainExitEdge) { + // all exit paths must be same or don't cross (will be inside loop) + BlockNode exitBlock = exitEdge.getTarget(); + if (!BlockUtils.isEqualPaths(mainOutBlock, exitBlock)) { + BlockNode crossBlock = BlockUtils.getPathCross(mth, mainOutBlock, exitBlock); + if (crossBlock != null) { + return false; + } + } + } + } + return true; + } + + private BlockNode makeEndlessLoop(IRegion curRegion, RegionStack stack, LoopInfo loop, BlockNode loopStart) { + LoopRegion loopRegion = new LoopRegion(curRegion, loop, null, false); + curRegion.getSubBlocks().add(loopRegion); + + loopStart.remove(AType.LOOP); + regionMaker.clearBlockProcessedState(loopStart); + stack.push(loopRegion); + + BlockNode out = null; + // insert 'break' for exits + List exitEdges = loop.getExitEdges(); + if (exitEdges.size() == 1) { + Edge exitEdge = exitEdges.get(0); + BlockNode exit = exitEdge.getTarget(); + if (insertLoopBreak(stack, loop, exit, exitEdge)) { + BlockNode nextBlock = getNextBlock(exit); + if (nextBlock != null) { + stack.addExit(nextBlock); + out = nextBlock; + } + } + } else { + for (Edge exitEdge : exitEdges) { + BlockNode exit = exitEdge.getTarget(); + List blocks = BlockUtils.bitSetToBlocks(mth, exit.getDomFrontier()); + for (BlockNode block : blocks) { + if (BlockUtils.isPathExists(exit, block)) { + stack.addExit(block); + insertLoopBreak(stack, loop, block, exitEdge); + out = block; + } else { + insertLoopBreak(stack, loop, exit, exitEdge); + } + } + } + } + + Region body = regionMaker.makeRegion(loopStart); + BlockNode loopEnd = loop.getEnd(); + if (!RegionUtils.isRegionContainsBlock(body, loopEnd) + && !loopEnd.contains(AType.EXC_HANDLER) + && !inExceptionHandlerBlocks(loopEnd)) { + body.getSubBlocks().add(loopEnd); + } + loopRegion.setBody(body); + + if (out == null) { + BlockNode next = getNextBlock(loopEnd); + out = RegionUtils.isRegionContainsBlock(body, next) ? null : next; + } + stack.pop(); + loopStart.addAttr(AType.LOOP, loop); + return out; + } + + private boolean inExceptionHandlerBlocks(BlockNode loopEnd) { + if (mth.getExceptionHandlersCount() == 0) { + return false; + } + for (ExceptionHandler eh : mth.getExceptionHandlers()) { + if (eh.getBlocks().contains(loopEnd)) { + return true; + } + } + return false; + } + + private boolean canInsertBreak(BlockNode exit) { + if (BlockUtils.containsExitInsn(exit)) { + return false; + } + List simplePath = BlockUtils.buildSimplePath(exit); + if (!simplePath.isEmpty()) { + BlockNode lastBlock = simplePath.get(simplePath.size() - 1); + if (lastBlock.isMthExitBlock() + || lastBlock.isReturnBlock() + || mth.isPreExitBlock(lastBlock)) { + return false; + } + } + // check if there no outer switch (TODO: very expensive check) + Set paths = BlockUtils.getAllPathsBlocks(mth.getEnterBlock(), exit); + for (BlockNode block : paths) { + if (BlockUtils.checkLastInsnType(block, InsnType.SWITCH)) { + return false; + } + } + return true; + } + + private boolean insertLoopBreak(RegionStack stack, LoopInfo loop, BlockNode loopExit, Edge exitEdge) { + BlockNode exit = exitEdge.getTarget(); + Edge insertEdge = null; + boolean confirm = false; + // process special cases: + // 1. jump to outer loop + BlockNode exitEnd = BlockUtils.followEmptyPath(exit); + List loops = exitEnd.getAll(AType.LOOP); + for (LoopInfo loopAtEnd : loops) { + if (loopAtEnd != loop && loop.hasParent(loopAtEnd)) { + insertEdge = exitEdge; + confirm = true; + break; + } + } + + if (!confirm) { + BlockNode insertBlock = null; + while (exit != null) { + if (insertBlock != null && isPathExists(loopExit, exit)) { + // found cross + if (canInsertBreak(insertBlock)) { + insertEdge = new Edge(insertBlock, insertBlock.getSuccessors().get(0)); + confirm = true; + break; + } + return false; + } + insertBlock = exit; + List cs = exit.getCleanSuccessors(); + exit = cs.size() == 1 ? cs.get(0) : null; + } + } + if (!confirm) { + return false; + } + InsnNode breakInsn = new InsnNode(InsnType.BREAK, 0); + breakInsn.addAttr(AType.LOOP, loop); + EdgeInsnAttr.addEdgeInsn(insertEdge, breakInsn); + stack.addExit(exit); + // add label to 'break' if needed + addBreakLabel(exitEdge, exit, breakInsn); + return true; + } + + private void addBreakLabel(Edge exitEdge, BlockNode exit, InsnNode breakInsn) { + BlockNode outBlock = BlockUtils.getNextBlock(exitEdge.getTarget()); + if (outBlock == null) { + return; + } + List exitLoop = mth.getAllLoopsForBlock(outBlock); + if (!exitLoop.isEmpty()) { + return; + } + List inLoops = mth.getAllLoopsForBlock(exitEdge.getSource()); + if (inLoops.size() < 2) { + return; + } + // search for parent loop + LoopInfo parentLoop = null; + for (LoopInfo loop : inLoops) { + if (loop.getParentLoop() == null) { + parentLoop = loop; + break; + } + } + if (parentLoop == null) { + return; + } + if (parentLoop.getEnd() != exit && !parentLoop.getExitNodes().contains(exit)) { + LoopLabelAttr labelAttr = new LoopLabelAttr(parentLoop); + breakInsn.addAttr(labelAttr); + parentLoop.getStart().addAttr(labelAttr); + } + } + + private static void insertContinue(LoopInfo loop) { + BlockNode loopEnd = loop.getEnd(); + List predecessors = loopEnd.getPredecessors(); + if (predecessors.size() <= 1) { + return; + } + Set loopExitNodes = loop.getExitNodes(); + for (BlockNode pred : predecessors) { + if (canInsertContinue(pred, predecessors, loopEnd, loopExitNodes)) { + InsnNode cont = new InsnNode(InsnType.CONTINUE, 0); + pred.getInstructions().add(cont); + } + } + } + + private static boolean canInsertContinue(BlockNode pred, List predecessors, BlockNode loopEnd, + Set loopExitNodes) { + if (!pred.contains(AFlag.SYNTHETIC) + || BlockUtils.checkLastInsnType(pred, InsnType.CONTINUE)) { + return false; + } + List preds = pred.getPredecessors(); + if (preds.isEmpty()) { + return false; + } + BlockNode codePred = preds.get(0); + if (codePred.contains(AFlag.ADDED_TO_REGION)) { + return false; + } + if (loopEnd.isDominator(codePred) + || loopExitNodes.contains(codePred)) { + return false; + } + if (isDominatedOnBlocks(codePred, predecessors)) { + return false; + } + boolean gotoExit = false; + for (BlockNode exit : loopExitNodes) { + if (BlockUtils.isPathExists(codePred, exit)) { + gotoExit = true; + break; + } + } + return gotoExit; + } + + private static boolean isDominatedOnBlocks(BlockNode dom, List blocks) { + for (BlockNode node : blocks) { + if (!node.isDominator(dom)) { + return false; + } + } + return true; + } +} diff --git a/jadx-core/src/main/java/jadx/core/dex/visitors/regions/maker/RegionMaker.java b/jadx-core/src/main/java/jadx/core/dex/visitors/regions/maker/RegionMaker.java new file mode 100644 index 000000000..aadd88add --- /dev/null +++ b/jadx-core/src/main/java/jadx/core/dex/visitors/regions/maker/RegionMaker.java @@ -0,0 +1,168 @@ +package jadx.core.dex.visitors.regions.maker; + +import java.util.ArrayList; +import java.util.BitSet; +import java.util.List; +import java.util.Objects; + +import jadx.core.dex.attributes.AFlag; +import jadx.core.dex.attributes.AType; +import jadx.core.dex.attributes.nodes.EdgeInsnAttr; +import jadx.core.dex.attributes.nodes.LoopInfo; +import jadx.core.dex.instructions.IfNode; +import jadx.core.dex.instructions.InsnType; +import jadx.core.dex.instructions.SwitchInsn; +import jadx.core.dex.nodes.BlockNode; +import jadx.core.dex.nodes.IRegion; +import jadx.core.dex.nodes.InsnContainer; +import jadx.core.dex.nodes.InsnNode; +import jadx.core.dex.nodes.MethodNode; +import jadx.core.dex.regions.Region; +import jadx.core.utils.BlockUtils; +import jadx.core.utils.exceptions.JadxOverflowException; + +import static jadx.core.utils.BlockUtils.getNextBlock; + +public class RegionMaker { + private final MethodNode mth; + private final RegionStack stack; + + private final IfRegionMaker ifMaker; + private final LoopRegionMaker loopMaker; + + private final BitSet processedBlocks; + private final int regionsLimit; + + private int regionsCount; + + public RegionMaker(MethodNode mth) { + this.mth = mth; + this.stack = new RegionStack(mth); + this.ifMaker = new IfRegionMaker(mth, this); + this.loopMaker = new LoopRegionMaker(mth, this, ifMaker); + int blocksCount = mth.getBasicBlocks().size(); + this.processedBlocks = new BitSet(blocksCount); + this.regionsLimit = blocksCount * 100; + } + + public Region makeMthRegion() { + return makeRegion(mth.getEnterBlock()); + } + + Region makeRegion(BlockNode startBlock) { + Objects.requireNonNull(startBlock); + Region region = new Region(stack.peekRegion()); + if (stack.containsExit(startBlock)) { + insertEdgeInsns(region, startBlock); + return region; + } + + int startBlockId = startBlock.getId(); + if (processedBlocks.get(startBlockId)) { + mth.addWarn("Removed duplicated region for block: " + startBlock + ' ' + startBlock.getAttributesString()); + return region; + } + processedBlocks.set(startBlockId); + + BlockNode next = startBlock; + while (next != null) { + next = traverse(region, next); + regionsCount++; + if (regionsCount > regionsLimit) { + throw new JadxOverflowException("Regions count limit reached"); + } + } + return region; + } + + /** + * Recursively traverse all blocks from 'block' until block from 'exits' + */ + private BlockNode traverse(IRegion r, BlockNode block) { + if (block.contains(AFlag.MTH_EXIT_BLOCK)) { + return null; + } + BlockNode next = null; + boolean processed = false; + + List loops = block.getAll(AType.LOOP); + int loopCount = loops.size(); + if (loopCount != 0 && block.contains(AFlag.LOOP_START)) { + if (loopCount == 1) { + next = loopMaker.process(r, loops.get(0), stack); + processed = true; + } else { + for (LoopInfo loop : loops) { + if (loop.getStart() == block) { + next = loopMaker.process(r, loop, stack); + processed = true; + break; + } + } + } + } + + InsnNode insn = BlockUtils.getLastInsn(block); + if (!processed && insn != null) { + switch (insn.getType()) { + case IF: + next = ifMaker.process(r, block, (IfNode) insn, stack); + processed = true; + break; + + case SWITCH: + SwitchRegionMaker switchMaker = new SwitchRegionMaker(mth, this); + next = switchMaker.process(r, block, (SwitchInsn) insn, stack); + processed = true; + break; + + case MONITOR_ENTER: + SynchronizedRegionMaker syncMaker = new SynchronizedRegionMaker(mth, this); + next = syncMaker.process(r, block, insn, stack); + processed = true; + break; + } + } + if (!processed) { + r.getSubBlocks().add(block); + next = getNextBlock(block); + } + if (next != null && !stack.containsExit(block) && !stack.containsExit(next)) { + return next; + } + return null; + } + + private void insertEdgeInsns(Region region, BlockNode exitBlock) { + List edgeInsns = exitBlock.getAll(AType.EDGE_INSN); + if (edgeInsns.isEmpty()) { + return; + } + List insns = new ArrayList<>(edgeInsns.size()); + addOneInsnOfType(insns, edgeInsns, InsnType.BREAK); + addOneInsnOfType(insns, edgeInsns, InsnType.CONTINUE); + region.add(new InsnContainer(insns)); + } + + private void addOneInsnOfType(List insns, List edgeInsns, InsnType insnType) { + for (EdgeInsnAttr edgeInsn : edgeInsns) { + InsnNode insn = edgeInsn.getInsn(); + if (insn.getType() == insnType) { + insns.add(insn); + return; + } + } + } + + RegionStack getStack() { + return stack; + } + + boolean isProcessed(BlockNode block) { + return processedBlocks.get(block.getId()); + } + + void clearBlockProcessedState(BlockNode block) { + processedBlocks.clear(block.getId()); + } +} diff --git a/jadx-core/src/main/java/jadx/core/dex/visitors/regions/RegionStack.java b/jadx-core/src/main/java/jadx/core/dex/visitors/regions/maker/RegionStack.java similarity index 93% rename from jadx-core/src/main/java/jadx/core/dex/visitors/regions/RegionStack.java rename to jadx-core/src/main/java/jadx/core/dex/visitors/regions/maker/RegionStack.java index 4af41396b..ed717f9a0 100644 --- a/jadx-core/src/main/java/jadx/core/dex/visitors/regions/RegionStack.java +++ b/jadx-core/src/main/java/jadx/core/dex/visitors/regions/maker/RegionStack.java @@ -1,4 +1,4 @@ -package jadx.core.dex.visitors.regions; +package jadx.core.dex.visitors.regions.maker; import java.util.ArrayDeque; import java.util.Collection; @@ -31,7 +31,7 @@ final class RegionStack { IRegion region; public State() { - exits = new HashSet<>(4); + exits = new HashSet<>(); } private State(State c, IRegion region) { @@ -113,6 +113,12 @@ final class RegionStack { return stack.size(); } + public RegionStack clear() { + stack.clear(); + curState = new State(); + return this; + } + @Override public String toString() { return "Region stack size: " + size() + ", last: " + curState; diff --git a/jadx-core/src/main/java/jadx/core/dex/visitors/regions/maker/SwitchRegionMaker.java b/jadx-core/src/main/java/jadx/core/dex/visitors/regions/maker/SwitchRegionMaker.java new file mode 100644 index 000000000..b6620c1c9 --- /dev/null +++ b/jadx-core/src/main/java/jadx/core/dex/visitors/regions/maker/SwitchRegionMaker.java @@ -0,0 +1,288 @@ +package jadx.core.dex.visitors.regions.maker; + +import java.util.ArrayList; +import java.util.BitSet; +import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import org.jetbrains.annotations.Nullable; + +import jadx.core.dex.attributes.AFlag; +import jadx.core.dex.attributes.nodes.LoopInfo; +import jadx.core.dex.attributes.nodes.RegionRefAttr; +import jadx.core.dex.instructions.InsnType; +import jadx.core.dex.instructions.SwitchInsn; +import jadx.core.dex.nodes.BlockNode; +import jadx.core.dex.nodes.IRegion; +import jadx.core.dex.nodes.InsnNode; +import jadx.core.dex.nodes.MethodNode; +import jadx.core.dex.regions.Region; +import jadx.core.dex.regions.SwitchRegion; +import jadx.core.utils.BlockUtils; +import jadx.core.utils.RegionUtils; +import jadx.core.utils.Utils; +import jadx.core.utils.exceptions.JadxRuntimeException; + +final class SwitchRegionMaker { + private final MethodNode mth; + private final RegionMaker regionMaker; + + SwitchRegionMaker(MethodNode mth, RegionMaker regionMaker) { + this.mth = mth; + this.regionMaker = regionMaker; + } + + BlockNode process(IRegion currentRegion, BlockNode block, SwitchInsn insn, RegionStack stack) { + // map case blocks to keys + int len = insn.getTargets().length; + Map> blocksMap = new LinkedHashMap<>(len); + BlockNode[] targetBlocksArr = insn.getTargetBlocks(); + for (int i = 0; i < len; i++) { + List keys = blocksMap.computeIfAbsent(targetBlocksArr[i], k -> new ArrayList<>(2)); + keys.add(insn.getKey(i)); + } + BlockNode defCase = insn.getDefTargetBlock(); + if (defCase != null) { + List keys = blocksMap.computeIfAbsent(defCase, k -> new ArrayList<>(1)); + keys.add(SwitchRegion.DEFAULT_CASE_KEY); + } + + SwitchRegion sw = new SwitchRegion(currentRegion, block); + insn.addAttr(new RegionRefAttr(sw)); + currentRegion.getSubBlocks().add(sw); + stack.push(sw); + + BlockNode out = calcSwitchOut(block, stack); + stack.addExit(out); + + processFallThroughCases(sw, out, stack, blocksMap); + removeEmptyCases(insn, sw, defCase); + + stack.pop(); + return out; + } + + private void processFallThroughCases(SwitchRegion sw, @Nullable BlockNode out, + RegionStack stack, Map> blocksMap) { + Map fallThroughCases = new LinkedHashMap<>(); + if (out != null) { + // detect fallthrough cases + BitSet caseBlocks = BlockUtils.blocksToBitSet(mth, blocksMap.keySet()); + caseBlocks.clear(out.getId()); + for (BlockNode successor : sw.getHeader().getCleanSuccessors()) { + BitSet df = successor.getDomFrontier(); + if (df.intersects(caseBlocks)) { + BlockNode fallThroughBlock = getOneIntersectionBlock(out, caseBlocks, df); + fallThroughCases.put(successor, fallThroughBlock); + } + } + // check fallthrough cases order + if (!fallThroughCases.isEmpty() && isBadCasesOrder(blocksMap, fallThroughCases)) { + Map> newBlocksMap = reOrderSwitchCases(blocksMap, fallThroughCases); + if (isBadCasesOrder(newBlocksMap, fallThroughCases)) { + mth.addWarnComment("Can't fix incorrect switch cases order, some code will duplicate"); + fallThroughCases.clear(); + } else { + blocksMap = newBlocksMap; + } + } + } + + for (Map.Entry> entry : blocksMap.entrySet()) { + List keysList = entry.getValue(); + BlockNode caseBlock = entry.getKey(); + if (stack.containsExit(caseBlock)) { + sw.addCase(keysList, new Region(stack.peekRegion())); + } else { + BlockNode next = fallThroughCases.get(caseBlock); + stack.addExit(next); + Region caseRegion = regionMaker.makeRegion(caseBlock); + stack.removeExit(next); + if (next != null) { + next.add(AFlag.FALL_THROUGH); + caseRegion.add(AFlag.FALL_THROUGH); + } + sw.addCase(keysList, caseRegion); + // 'break' instruction will be inserted in RegionMakerVisitor.PostRegionVisitor + } + } + } + + @Nullable + private BlockNode getOneIntersectionBlock(BlockNode out, BitSet caseBlocks, BitSet fallThroughSet) { + BitSet caseExits = BlockUtils.copyBlocksBitSet(mth, fallThroughSet); + caseExits.clear(out.getId()); + caseExits.and(caseBlocks); + return BlockUtils.bitSetToOneBlock(mth, caseExits); + } + + private @Nullable BlockNode calcSwitchOut(BlockNode block, RegionStack stack) { + // union of case blocks dominance frontier + // works if no fallthrough cases and no returns inside switch + BitSet outs = BlockUtils.newBlocksBitSet(mth); + for (BlockNode s : block.getCleanSuccessors()) { + if (s.contains(AFlag.LOOP_END)) { + // loop end dom frontier is loop start, ignore it + continue; + } + outs.or(s.getDomFrontier()); + } + outs.clear(block.getId()); + outs.clear(mth.getExitBlock().getId()); + if (outs.isEmpty()) { + // switch already contains method exit + // add everything, out block not needed + return mth.getExitBlock(); + } + + BlockNode out = null; + if (outs.cardinality() == 1) { + // single exit + out = BlockUtils.bitSetToOneBlock(mth, outs); + } else { + // several switch exits + // possible 'return', 'continue' or fallthrough in one of the cases + LoopInfo loop = mth.getLoopForBlock(block); + if (loop != null) { + outs.andNot(loop.getStart().getPostDoms()); + outs.andNot(loop.getEnd().getPostDoms()); + BlockNode loopEnd = loop.getEnd(); + if (outs.cardinality() == 2 && outs.get(loopEnd.getId())) { + // insert 'continue' for cases lead to loop end + // expect only 2 exits: loop end and switch out + List outList = BlockUtils.bitSetToBlocks(mth, outs); + outList.remove(loopEnd); + BlockNode possibleOut = Utils.getOne(outList); + if (possibleOut != null && insertContinueInSwitch(block, possibleOut, loopEnd)) { + outs.clear(loopEnd.getId()); + out = possibleOut; + } + } + } + if (out == null) { + BlockNode imPostDom = block.getIPostDom(); + if (outs.get(imPostDom.getId())) { + out = imPostDom; + } else { + outs.andNot(block.getPostDoms()); + out = BlockUtils.bitSetToOneBlock(mth, outs); + } + } + } + if (out != null && mth.isPreExitBlock(out)) { + // include 'return' or 'throw' in case blocks + out = mth.getExitBlock(); + } + BlockNode imPostDom = block.getIPostDom(); + if (out != imPostDom && !mth.isPreExitBlock(imPostDom)) { + // stop other paths at common exit + stack.addExit(imPostDom); + } + if (block.getCleanSuccessors().contains(imPostDom)) { + // add exit to stop on empty 'default' block + stack.addExit(imPostDom); + } + if (out == null) { + mth.addWarnComment("Failed to find 'out' block for switch in " + block + ". Please report as an issue."); + // fallback option; should work in most cases + out = block.getIPostDom(); + } + if (out != null && regionMaker.isProcessed(out)) { + // 'out' block already processed, prevent endless loop + throw new JadxRuntimeException("Failed to find switch 'out' block (already processed)"); + } + return out; + } + + /** + * Remove empty case blocks: + * 1. single 'default' case + * 2. filler cases if switch is 'packed' and 'default' case is empty + */ + private void removeEmptyCases(SwitchInsn insn, SwitchRegion sw, BlockNode defCase) { + boolean defaultCaseIsEmpty; + if (defCase == null) { + defaultCaseIsEmpty = true; + } else { + defaultCaseIsEmpty = sw.getCases().stream() + .anyMatch(c -> c.getKeys().contains(SwitchRegion.DEFAULT_CASE_KEY) + && RegionUtils.isEmpty(c.getContainer())); + } + if (defaultCaseIsEmpty) { + sw.getCases().removeIf(caseInfo -> { + if (RegionUtils.isEmpty(caseInfo.getContainer())) { + List keys = caseInfo.getKeys(); + if (keys.contains(SwitchRegion.DEFAULT_CASE_KEY)) { + return true; + } + if (insn.isPacked()) { + return true; + } + } + return false; + }); + } + } + + private boolean isBadCasesOrder(Map> blocksMap, Map fallThroughCases) { + BlockNode nextCaseBlock = null; + for (BlockNode caseBlock : blocksMap.keySet()) { + if (nextCaseBlock != null && !caseBlock.equals(nextCaseBlock)) { + return true; + } + nextCaseBlock = fallThroughCases.get(caseBlock); + } + return nextCaseBlock != null; + } + + private Map> reOrderSwitchCases(Map> blocksMap, + Map fallThroughCases) { + List list = new ArrayList<>(blocksMap.size()); + list.addAll(blocksMap.keySet()); + list.sort((a, b) -> { + BlockNode nextA = fallThroughCases.get(a); + if (nextA != null) { + if (b.equals(nextA)) { + return -1; + } + } else if (a.equals(fallThroughCases.get(b))) { + return 1; + } + return 0; + }); + + Map> newBlocksMap = new LinkedHashMap<>(blocksMap.size()); + for (BlockNode key : list) { + newBlocksMap.put(key, blocksMap.get(key)); + } + return newBlocksMap; + } + + private boolean insertContinueInSwitch(BlockNode switchBlock, BlockNode switchOut, BlockNode loopEnd) { + boolean inserted = false; + for (BlockNode caseBlock : switchBlock.getCleanSuccessors()) { + if (caseBlock.getDomFrontier().get(loopEnd.getId()) && caseBlock != switchOut) { + // search predecessor of loop end on path from this successor + Set list = new HashSet<>(BlockUtils.collectBlocksDominatedBy(mth, caseBlock, caseBlock)); + if (list.contains(switchOut) || switchOut.getPredecessors().stream().anyMatch(list::contains)) { + // 'continue' not needed + } else { + for (BlockNode p : loopEnd.getPredecessors()) { + if (list.contains(p)) { + if (p.isSynthetic()) { + p.getInstructions().add(new InsnNode(InsnType.CONTINUE, 0)); + inserted = true; + } + break; + } + } + } + } + } + return inserted; + } + +} diff --git a/jadx-core/src/main/java/jadx/core/dex/visitors/regions/maker/SynchronizedRegionMaker.java b/jadx-core/src/main/java/jadx/core/dex/visitors/regions/maker/SynchronizedRegionMaker.java new file mode 100644 index 000000000..97048534a --- /dev/null +++ b/jadx-core/src/main/java/jadx/core/dex/visitors/regions/maker/SynchronizedRegionMaker.java @@ -0,0 +1,162 @@ +package jadx.core.dex.visitors.regions.maker; + +import java.util.HashSet; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Set; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import jadx.core.dex.attributes.AFlag; +import jadx.core.dex.instructions.InsnType; +import jadx.core.dex.instructions.args.InsnArg; +import jadx.core.dex.nodes.BlockNode; +import jadx.core.dex.nodes.IContainer; +import jadx.core.dex.nodes.IRegion; +import jadx.core.dex.nodes.InsnNode; +import jadx.core.dex.nodes.MethodNode; +import jadx.core.dex.regions.Region; +import jadx.core.dex.regions.SynchronizedRegion; +import jadx.core.dex.visitors.regions.CleanRegions; +import jadx.core.utils.BlockUtils; +import jadx.core.utils.InsnRemover; +import jadx.core.utils.Utils; + +import static jadx.core.utils.BlockUtils.getNextBlock; +import static jadx.core.utils.BlockUtils.isPathExists; + +public class SynchronizedRegionMaker { + private static final Logger LOG = LoggerFactory.getLogger(SynchronizedRegionMaker.class); + private final MethodNode mth; + private final RegionMaker regionMaker; + + SynchronizedRegionMaker(MethodNode mth, RegionMaker regionMaker) { + this.mth = mth; + this.regionMaker = regionMaker; + } + + BlockNode process(IRegion curRegion, BlockNode block, InsnNode insn, RegionStack stack) { + SynchronizedRegion synchRegion = new SynchronizedRegion(curRegion, insn); + synchRegion.getSubBlocks().add(block); + curRegion.getSubBlocks().add(synchRegion); + + Set exits = new LinkedHashSet<>(); + Set cacheSet = new HashSet<>(); + traverseMonitorExits(synchRegion, insn.getArg(0), block, exits, cacheSet); + + for (InsnNode exitInsn : synchRegion.getExitInsns()) { + BlockNode insnBlock = BlockUtils.getBlockByInsn(mth, exitInsn); + if (insnBlock != null) { + insnBlock.add(AFlag.DONT_GENERATE); + } + // remove arg from MONITOR_EXIT to allow inline in MONITOR_ENTER + exitInsn.removeArg(0); + exitInsn.add(AFlag.DONT_GENERATE); + } + + BlockNode body = getNextBlock(block); + if (body == null) { + mth.addWarn("Unexpected end of synchronized block"); + return null; + } + BlockNode exit = null; + if (exits.size() == 1) { + exit = getNextBlock(exits.iterator().next()); + } else if (exits.size() > 1) { + cacheSet.clear(); + exit = traverseMonitorExitsCross(body, exits, cacheSet); + } + + stack.push(synchRegion); + if (exit != null) { + stack.addExit(exit); + } else { + for (BlockNode exitBlock : exits) { + // don't add exit blocks which leads to method end blocks ('return', 'throw', etc) + List list = BlockUtils.buildSimplePath(exitBlock); + if (list.isEmpty() || !BlockUtils.isExitBlock(mth, Utils.last(list))) { + stack.addExit(exitBlock); + // we can still try using this as an exit block to make sure it's visited. + exit = exitBlock; + } + } + } + synchRegion.getSubBlocks().add(regionMaker.makeRegion(body)); + stack.pop(); + return exit; + } + + /** + * Traverse from monitor-enter thru successors and collect blocks contains monitor-exit + */ + private static void traverseMonitorExits(SynchronizedRegion region, InsnArg arg, BlockNode block, Set exits, + Set visited) { + visited.add(block); + for (InsnNode insn : block.getInstructions()) { + if (insn.getType() == InsnType.MONITOR_EXIT + && insn.getArgsCount() > 0 + && insn.getArg(0).equals(arg)) { + exits.add(block); + region.getExitInsns().add(insn); + return; + } + } + for (BlockNode node : block.getSuccessors()) { + if (!visited.contains(node)) { + traverseMonitorExits(region, arg, node, exits, visited); + } + } + } + + /** + * Traverse from monitor-enter thru successors and search for exit paths cross + */ + private static BlockNode traverseMonitorExitsCross(BlockNode block, Set exits, Set visited) { + visited.add(block); + for (BlockNode node : block.getCleanSuccessors()) { + boolean cross = true; + for (BlockNode exitBlock : exits) { + boolean p = isPathExists(exitBlock, node); + if (!p) { + cross = false; + break; + } + } + if (cross) { + return node; + } + if (!visited.contains(node)) { + BlockNode res = traverseMonitorExitsCross(node, exits, visited); + if (res != null) { + return res; + } + } + } + return null; + } + + public static void removeSynchronized(MethodNode mth) { + Region startRegion = mth.getRegion(); + List subBlocks = startRegion.getSubBlocks(); + if (!subBlocks.isEmpty() && subBlocks.get(0) instanceof SynchronizedRegion) { + SynchronizedRegion synchRegion = (SynchronizedRegion) subBlocks.get(0); + InsnNode synchInsn = synchRegion.getEnterInsn(); + if (!synchInsn.getArg(0).isThis()) { + LOG.warn("In synchronized method {}, top region not synchronized by 'this' {}", mth, synchInsn); + return; + } + // replace synchronized block with inner region + startRegion.getSubBlocks().set(0, synchRegion.getRegion()); + // remove 'monitor-enter' instruction + InsnRemover.remove(mth, synchInsn); + // remove 'monitor-exit' instruction + for (InsnNode exit : synchRegion.getExitInsns()) { + InsnRemover.remove(mth, exit); + } + // run region cleaner again + CleanRegions.process(mth); + // assume that CodeShrinker will be run after this + } + } +} diff --git a/jadx-core/src/main/java/jadx/core/utils/BlockUtils.java b/jadx-core/src/main/java/jadx/core/utils/BlockUtils.java index 9c6d0b761..b9cbc3648 100644 --- a/jadx-core/src/main/java/jadx/core/utils/BlockUtils.java +++ b/jadx-core/src/main/java/jadx/core/utils/BlockUtils.java @@ -1202,4 +1202,48 @@ public class BlockUtils { } return block.get(AType.EXC_CATCH); } + + public static boolean isEqualPaths(BlockNode b1, BlockNode b2) { + if (b1 == b2) { + return true; + } + if (b1 == null || b2 == null) { + return false; + } + return isEqualReturnBlocks(b1, b2) || isEmptySyntheticPath(b1, b2); + } + + private static boolean isEmptySyntheticPath(BlockNode b1, BlockNode b2) { + BlockNode n1 = followEmptyPath(b1); + BlockNode n2 = followEmptyPath(b2); + return n1 == n2 || isEqualReturnBlocks(n1, n2); + } + + public static boolean isEqualReturnBlocks(BlockNode b1, BlockNode b2) { + if (!b1.isReturnBlock() || !b2.isReturnBlock()) { + return false; + } + List b1Insns = b1.getInstructions(); + List b2Insns = b2.getInstructions(); + if (b1Insns.size() != 1 || b2Insns.size() != 1) { + return false; + } + InsnNode i1 = b1Insns.get(0); + InsnNode i2 = b2Insns.get(0); + if (i1.getArgsCount() != i2.getArgsCount()) { + return false; + } + if (i1.getArgsCount() == 0) { + return true; + } + InsnArg firstArg = i1.getArg(0); + InsnArg secondArg = i2.getArg(0); + if (firstArg.isSameConst(secondArg)) { + return true; + } + if (i1.getSourceLine() != i2.getSourceLine()) { + return false; + } + return firstArg.equals(secondArg); + } }