From 6aeaf6aca936433912a0cb5f6fd69a411a50d416 Mon Sep 17 00:00:00 2001 From: Skylot <118523+skylot@users.noreply.github.com> Date: Thu, 13 Nov 2025 20:43:43 +0000 Subject: [PATCH] fix: extract common switch break, remove unreachable (#2697) --- jadx-core/src/main/java/jadx/core/Jadx.java | 3 + .../dex/attributes/nodes/RegionRefAttr.java | 2 +- .../jadx/core/dex/nodes/InsnContainer.java | 10 +- .../java/jadx/core/dex/nodes/InsnNode.java | 13 + .../jadx/core/dex/regions/SwitchRegion.java | 4 +- .../regions/AbstractRegionVisitor.java | 2 +- .../visitors/regions/PostProcessRegions.java | 67 +---- .../visitors/regions/SwitchBreakVisitor.java | 231 ++++++++++++++++++ .../visitors/regions/TracedRegionVisitor.java | 6 +- .../visitors/regions/maker/RegionStack.java | 5 +- .../regions/maker/SwitchRegionMaker.java | 84 ++++++- .../main/java/jadx/core/utils/ListUtils.java | 5 +- .../java/jadx/core/utils/RegionUtils.java | 36 ++- .../integration/switches/TestSwitch2.java | 5 +- .../switches/TestSwitchBreak2.java | 53 ++++ .../switches/TestSwitchWithTryCatch.java | 6 +- 16 files changed, 433 insertions(+), 99 deletions(-) create mode 100644 jadx-core/src/main/java/jadx/core/dex/visitors/regions/SwitchBreakVisitor.java create mode 100644 jadx-core/src/test/java/jadx/tests/integration/switches/TestSwitchBreak2.java diff --git a/jadx-core/src/main/java/jadx/core/Jadx.java b/jadx-core/src/main/java/jadx/core/Jadx.java index 77d310cdf..fd1c9f3a7 100644 --- a/jadx-core/src/main/java/jadx/core/Jadx.java +++ b/jadx-core/src/main/java/jadx/core/Jadx.java @@ -66,6 +66,7 @@ import jadx.core.dex.visitors.regions.IfRegionVisitor; import jadx.core.dex.visitors.regions.LoopRegionVisitor; import jadx.core.dex.visitors.regions.RegionMakerVisitor; import jadx.core.dex.visitors.regions.ReturnVisitor; +import jadx.core.dex.visitors.regions.SwitchBreakVisitor; import jadx.core.dex.visitors.regions.SwitchOverStringVisitor; import jadx.core.dex.visitors.regions.variables.ProcessVariables; import jadx.core.dex.visitors.rename.CodeRenameVisitor; @@ -196,12 +197,14 @@ public class Jadx { passes.add(new FixAccessModifiers()); passes.add(new ClassModifier()); passes.add(new LoopRegionVisitor()); + passes.add(new SwitchBreakVisitor()); if (args.isInlineMethods()) { passes.add(new MarkMethodsForInline()); } passes.add(new ProcessVariables()); passes.add(new ApplyVariableNames()); + passes.add(new PrepareForCodeGen()); if (args.isCfgOutput()) { passes.add(DotGraphVisitor.dumpRegions()); diff --git a/jadx-core/src/main/java/jadx/core/dex/attributes/nodes/RegionRefAttr.java b/jadx-core/src/main/java/jadx/core/dex/attributes/nodes/RegionRefAttr.java index 76d47ddef..41f1dfc40 100644 --- a/jadx-core/src/main/java/jadx/core/dex/attributes/nodes/RegionRefAttr.java +++ b/jadx-core/src/main/java/jadx/core/dex/attributes/nodes/RegionRefAttr.java @@ -25,6 +25,6 @@ public class RegionRefAttr implements IJadxAttribute { @Override public String toString() { - return "RegionRef:" + region; + return "RegionRef:" + region.baseString(); } } diff --git a/jadx-core/src/main/java/jadx/core/dex/nodes/InsnContainer.java b/jadx-core/src/main/java/jadx/core/dex/nodes/InsnContainer.java index 537712fbc..bc3e4aa46 100644 --- a/jadx-core/src/main/java/jadx/core/dex/nodes/InsnContainer.java +++ b/jadx-core/src/main/java/jadx/core/dex/nodes/InsnContainer.java @@ -1,6 +1,6 @@ package jadx.core.dex.nodes; -import java.util.Collections; +import java.util.ArrayList; import java.util.List; import jadx.core.dex.attributes.AttrNode; @@ -14,7 +14,9 @@ public final class InsnContainer extends AttrNode implements IBlock { private final List insns; public InsnContainer(InsnNode insn) { - this.insns = Collections.singletonList(insn); + List list = new ArrayList<>(1); + list.add(insn); + this.insns = list; } public InsnContainer(List insns) { @@ -28,11 +30,11 @@ public final class InsnContainer extends AttrNode implements IBlock { @Override public String baseString() { - return Integer.toString(insns.size()); + return "IC"; } @Override public String toString() { - return "InsnContainer:" + insns.size(); + return "InsnContainer"; } } diff --git a/jadx-core/src/main/java/jadx/core/dex/nodes/InsnNode.java b/jadx-core/src/main/java/jadx/core/dex/nodes/InsnNode.java index 086e8abd5..7923879e7 100644 --- a/jadx-core/src/main/java/jadx/core/dex/nodes/InsnNode.java +++ b/jadx-core/src/main/java/jadx/core/dex/nodes/InsnNode.java @@ -217,6 +217,19 @@ public class InsnNode extends LineAttrNode { } } + public boolean isExitEdgeInsn() { + switch (getType()) { + case RETURN: + case THROW: + case CONTINUE: + case BREAK: + return true; + + default: + return false; + } + } + public boolean canRemoveResult() { switch (getType()) { case INVOKE: diff --git a/jadx-core/src/main/java/jadx/core/dex/regions/SwitchRegion.java b/jadx-core/src/main/java/jadx/core/dex/regions/SwitchRegion.java index 045457df9..d3c16e771 100644 --- a/jadx-core/src/main/java/jadx/core/dex/regions/SwitchRegion.java +++ b/jadx-core/src/main/java/jadx/core/dex/regions/SwitchRegion.java @@ -86,13 +86,13 @@ public final class SwitchRegion extends AbstractRegion implements IBranchRegion @Override public String baseString() { - return header.baseString(); + return "SW:" + header.baseString(); } @Override public String toString() { StringBuilder sb = new StringBuilder(); - sb.append("Switch: ").append(cases.size()); + sb.append("Switch: ").append(header.baseString()); for (CaseInfo caseInfo : cases) { List keyStrings = Utils.collectionMap(caseInfo.getKeys(), k -> k == DEFAULT_CASE_KEY ? "default" : k.toString()); diff --git a/jadx-core/src/main/java/jadx/core/dex/visitors/regions/AbstractRegionVisitor.java b/jadx-core/src/main/java/jadx/core/dex/visitors/regions/AbstractRegionVisitor.java index 51acf5a6d..f09d9ff61 100644 --- a/jadx-core/src/main/java/jadx/core/dex/visitors/regions/AbstractRegionVisitor.java +++ b/jadx-core/src/main/java/jadx/core/dex/visitors/regions/AbstractRegionVisitor.java @@ -12,7 +12,7 @@ public abstract class AbstractRegionVisitor implements IRegionVisitor { } @Override - public void processBlock(MethodNode mth, IBlock container) { + public void processBlock(MethodNode mth, IBlock block) { } @Override diff --git a/jadx-core/src/main/java/jadx/core/dex/visitors/regions/PostProcessRegions.java b/jadx-core/src/main/java/jadx/core/dex/visitors/regions/PostProcessRegions.java index 0e045e8a1..8e5242841 100644 --- a/jadx-core/src/main/java/jadx/core/dex/visitors/regions/PostProcessRegions.java +++ b/jadx-core/src/main/java/jadx/core/dex/visitors/regions/PostProcessRegions.java @@ -1,20 +1,11 @@ package jadx.core.dex.visitors.regions; -import java.util.ArrayList; import java.util.Collections; -import java.util.HashSet; import java.util.List; -import java.util.Set; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import jadx.core.dex.attributes.AFlag; import jadx.core.dex.attributes.AType; import jadx.core.dex.attributes.nodes.EdgeInsnAttr; -import jadx.core.dex.instructions.InsnType; import jadx.core.dex.nodes.BlockNode; -import jadx.core.dex.nodes.IBlock; import jadx.core.dex.nodes.IContainer; import jadx.core.dex.nodes.IRegion; import jadx.core.dex.nodes.InsnContainer; @@ -23,11 +14,9 @@ import jadx.core.dex.nodes.MethodNode; import jadx.core.dex.regions.Region; import jadx.core.dex.regions.SwitchRegion; import jadx.core.dex.regions.loops.LoopRegion; -import jadx.core.utils.RegionUtils; - -final class PostProcessRegions extends AbstractRegionVisitor { - private static final Logger LOG = LoggerFactory.getLogger(PostProcessRegions.class); +import jadx.core.dex.visitors.regions.maker.SwitchRegionMaker; +public final class PostProcessRegions extends AbstractRegionVisitor { private static final IRegionVisitor INSTANCE = new PostProcessRegions(); static void process(MethodNode mth) { @@ -41,8 +30,7 @@ final class PostProcessRegions extends AbstractRegionVisitor { LoopRegion loop = (LoopRegion) region; loop.mergePreCondition(); } else if (region instanceof SwitchRegion) { - // insert 'break' in switch cases (run after try/catch insertion) - processSwitch(mth, (SwitchRegion) region); + SwitchRegionMaker.insertBreaks(mth, (SwitchRegion) region); } else if (region instanceof Region) { insertEdgeInsn((Region) region); } @@ -76,55 +64,6 @@ final class PostProcessRegions extends AbstractRegionVisitor { region.add(new InsnContainer(insns)); } - private static void processSwitch(MethodNode mth, SwitchRegion sw) { - for (IContainer c : sw.getBranches()) { - if (c instanceof Region) { - Set blocks = new HashSet<>(); - RegionUtils.getAllRegionBlocks(c, blocks); - if (blocks.isEmpty()) { - addBreakToContainer((Region) c); - } else { - for (IBlock block : blocks) { - if (block instanceof BlockNode) { - addBreakForBlock(mth, c, blocks, (BlockNode) block); - } - } - } - } - } - } - - private static void addBreakToContainer(Region c) { - if (RegionUtils.hasExitEdge(c)) { - return; - } - List insns = new ArrayList<>(1); - insns.add(new InsnNode(InsnType.BREAK, 0)); - c.add(new InsnContainer(insns)); - } - - private static void addBreakForBlock(MethodNode mth, IContainer c, Set blocks, BlockNode bn) { - for (BlockNode s : bn.getCleanSuccessors()) { - if (!blocks.contains(s) - && !bn.contains(AFlag.ADDED_TO_REGION) - && !s.contains(AFlag.FALL_THROUGH)) { - addBreak(mth, c, bn); - return; - } - } - } - - private static void addBreak(MethodNode mth, IContainer c, BlockNode bn) { - IContainer blockContainer = RegionUtils.getBlockContainer(c, bn); - if (blockContainer instanceof Region) { - addBreakToContainer((Region) blockContainer); - } else if (c instanceof Region) { - addBreakToContainer((Region) c); - } else { - LOG.warn("Can't insert break, container: {}, block: {}, mth: {}", blockContainer, bn, mth); - } - } - private PostProcessRegions() { // singleton } diff --git a/jadx-core/src/main/java/jadx/core/dex/visitors/regions/SwitchBreakVisitor.java b/jadx-core/src/main/java/jadx/core/dex/visitors/regions/SwitchBreakVisitor.java new file mode 100644 index 000000000..6b49eda4b --- /dev/null +++ b/jadx-core/src/main/java/jadx/core/dex/visitors/regions/SwitchBreakVisitor.java @@ -0,0 +1,231 @@ +package jadx.core.dex.visitors.regions; + +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.concurrent.atomic.AtomicInteger; + +import org.jetbrains.annotations.Nullable; + +import jadx.core.dex.attributes.AFlag; +import jadx.core.dex.attributes.AType; +import jadx.core.dex.attributes.nodes.CodeFeaturesAttr; +import jadx.core.dex.attributes.nodes.RegionRefAttr; +import jadx.core.dex.instructions.InsnType; +import jadx.core.dex.nodes.IBlock; +import jadx.core.dex.nodes.IBranchRegion; +import jadx.core.dex.nodes.IContainer; +import jadx.core.dex.nodes.IRegion; +import jadx.core.dex.nodes.InsnNode; +import jadx.core.dex.nodes.MethodNode; +import jadx.core.dex.regions.SwitchRegion; +import jadx.core.dex.visitors.AbstractVisitor; +import jadx.core.dex.visitors.JadxVisitor; +import jadx.core.dex.visitors.regions.maker.SwitchRegionMaker; +import jadx.core.utils.BlockUtils; +import jadx.core.utils.ListUtils; +import jadx.core.utils.RegionUtils; +import jadx.core.utils.exceptions.JadxException; + +import static jadx.core.dex.attributes.nodes.CodeFeaturesAttr.CodeFeature.SWITCH; + +@JadxVisitor( + name = "SwitchBreakVisitor", + desc = "Optimize 'break' instruction: common code extract, remove unreachable", + runAfter = LoopRegionVisitor.class // can add 'continue' at case end +) +public class SwitchBreakVisitor extends AbstractVisitor { + + @Override + public void visit(MethodNode mth) throws JadxException { + if (CodeFeaturesAttr.contains(mth, SWITCH)) { + DepthRegionTraversal.traverse(mth, new ExtractCommonBreak()); + DepthRegionTraversal.traverse(mth, new RemoveUnreachableBreak()); + } + } + + private static final class ExtractCommonBreak extends BaseSwitchRegionVisitor { + @Override + public boolean switchVisitCondition(MethodNode mth, SwitchRegion switchRegion) { + return countBreaks(mth, switchRegion) > 1; + } + + @Override + public void processRegion(MethodNode mth, IRegion region) { + if (region instanceof IBranchRegion) { + // if break in all branches extract to parent region + processBranchRegion(region); + } + } + + private void processBranchRegion(IRegion region) { + IRegion parentRegion = region.getParent(); + if (parentRegion.contains(AFlag.FALL_THROUGH)) { + // fallthrough case, can't extract break + return; + } + boolean dontAddCommonBreak = false; + IBlock lastParentBlock = RegionUtils.getLastBlock(parentRegion); + if (BlockUtils.containsExitInsn(lastParentBlock)) { + if (isBreakBlock(lastParentBlock)) { + // parent block already contains 'break' + dontAddCommonBreak = true; + } else { + // can't add 'break' after 'return', 'throw' or 'continue' + return; + } + } + List branches = ((IBranchRegion) region).getBranches(); + boolean removeBranchBreaks = false; + boolean removeCommonBreak = true; // all branches contain exit insns, common break is unreachable + for (IContainer branch : branches) { + if (branch == null) { + removeCommonBreak = false; + continue; + } + IBlock lastBlock = RegionUtils.getLastBlock(branch); + InsnNode lastInsn = BlockUtils.getLastInsn(lastBlock); + if (lastInsn == null) { + return; + } + if (lastInsn.getType() == InsnType.BREAK) { + removeBranchBreaks = true; + removeCommonBreak = false; + } else if (!lastInsn.isExitEdgeInsn()) { + removeCommonBreak = false; + } + } + if (removeBranchBreaks) { + // common 'break' confirmed + for (IContainer branch : branches) { + if (branch == null) { + continue; + } + // remove breaks from all branches + IBlock lastBlock = RegionUtils.getLastBlock(branch); + if (lastBlock != null) { + removeBreak(lastBlock, branch); + } + } + if (!dontAddCommonBreak) { + addBreakRegion.add(parentRegion); + } + } + if (removeCommonBreak && lastParentBlock != null) { + removeBreak(lastParentBlock, parentRegion); + } + } + + private int countBreaks(MethodNode mth, IRegion region) { + AtomicInteger count = new AtomicInteger(0); + RegionUtils.visitBlocks(mth, region, block -> { + if (isBreakBlock(block)) { + count.incrementAndGet(); + } + }); + return count.get(); + } + } + + private static final class RemoveUnreachableBreak extends BaseSwitchRegionVisitor { + @Override + public void processRegion(MethodNode mth, IRegion region) { + List subBlocks = region.getSubBlocks(); + IContainer lastContainer = ListUtils.last(subBlocks); + if (lastContainer instanceof IBlock) { + IBlock block = (IBlock) lastContainer; + if (isBreakBlock(block) && isPrevInsnIsExit(block, subBlocks)) { + removeBreak(block, region); + } + } + } + + private boolean isPrevInsnIsExit(IBlock breakBlock, List subBlocks) { + InsnNode prevInsn = null; + if (breakBlock.getInstructions().size() > 1) { + // check prev insn in same block + List insns = breakBlock.getInstructions(); + prevInsn = insns.get(insns.size() - 2); + } else if (subBlocks.size() > 1) { + IContainer prev = subBlocks.get(subBlocks.size() - 2); + if (prev instanceof IBlock) { + List insns = ((IBlock) prev).getInstructions(); + prevInsn = ListUtils.last(insns); + } + } + return prevInsn != null && prevInsn.isExitEdgeInsn(); + } + } + + private abstract static class BaseSwitchRegionVisitor extends AbstractRegionVisitor { + protected final Set addBreakRegion = new HashSet<>(); + protected final Set cleanupSet = new HashSet<>(); + protected SwitchRegion currentSwitch; + + public abstract void processRegion(MethodNode mth, IRegion region); + + public boolean switchVisitCondition(MethodNode mth, SwitchRegion switchRegion) { + return true; + } + + @Override + public boolean enterRegion(MethodNode mth, IRegion region) { + if (region instanceof SwitchRegion) { + SwitchRegion switchRegion = (SwitchRegion) region; + this.currentSwitch = switchRegion; + return switchVisitCondition(mth, switchRegion); + } + if (currentSwitch == null) { + return true; + } + processRegion(mth, region); + return true; + } + + @Override + public void leaveRegion(MethodNode mth, IRegion region) { + if (region == currentSwitch) { + currentSwitch = null; + addBreakRegion.clear(); + cleanupSet.clear(); + return; + } + if (addBreakRegion.contains(region)) { + addBreakRegion.remove(region); + region.getSubBlocks().add(SwitchRegionMaker.buildBreakContainer(currentSwitch)); + } + if (cleanupSet.contains(region)) { + cleanupSet.remove(region); + region.getSubBlocks().removeIf(r -> r.contains(AFlag.REMOVE)); + } + } + + protected boolean isBreakBlock(@Nullable IBlock block) { + if (block != null) { + InsnNode lastInsn = ListUtils.last(block.getInstructions()); + if (lastInsn != null && lastInsn.getType() == InsnType.BREAK) { + RegionRefAttr regionRefAttr = lastInsn.get(AType.REGION_REF); + return regionRefAttr != null && regionRefAttr.getRegion() == currentSwitch; + } + } + return false; + } + + protected void removeBreak(IBlock breakBlock, IContainer parentContainer) { + List instructions = breakBlock.getInstructions(); + InsnNode last = ListUtils.last(instructions); + if (last != null && last.getType() == InsnType.BREAK) { + ListUtils.removeLast(instructions); + if (instructions.isEmpty()) { + breakBlock.add(AFlag.REMOVE); + cleanupSet.add(parentContainer); + } + } + } + } + + @Override + public String getName() { + return "SwitchBreakVisitor"; + } +} diff --git a/jadx-core/src/main/java/jadx/core/dex/visitors/regions/TracedRegionVisitor.java b/jadx-core/src/main/java/jadx/core/dex/visitors/regions/TracedRegionVisitor.java index 9d66c84d1..3705082c3 100644 --- a/jadx-core/src/main/java/jadx/core/dex/visitors/regions/TracedRegionVisitor.java +++ b/jadx-core/src/main/java/jadx/core/dex/visitors/regions/TracedRegionVisitor.java @@ -18,12 +18,12 @@ public abstract class TracedRegionVisitor implements IRegionVisitor { } @Override - public void processBlock(MethodNode mth, IBlock container) { + public void processBlock(MethodNode mth, IBlock block) { IRegion curRegion = regionStack.peek(); - processBlockTraced(mth, container, curRegion); + processBlockTraced(mth, block, curRegion); } - public abstract void processBlockTraced(MethodNode mth, IBlock container, IRegion currentRegion); + public abstract void processBlockTraced(MethodNode mth, IBlock block, IRegion parentRegion); @Override public void leaveRegion(MethodNode mth, IRegion region) { diff --git a/jadx-core/src/main/java/jadx/core/dex/visitors/regions/maker/RegionStack.java b/jadx-core/src/main/java/jadx/core/dex/visitors/regions/maker/RegionStack.java index ed717f9a0..1a1441938 100644 --- a/jadx-core/src/main/java/jadx/core/dex/visitors/regions/maker/RegionStack.java +++ b/jadx-core/src/main/java/jadx/core/dex/visitors/regions/maker/RegionStack.java @@ -6,6 +6,7 @@ import java.util.Deque; import java.util.HashSet; import java.util.Set; +import org.jetbrains.annotations.Nullable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -83,7 +84,7 @@ final class RegionStack { * * @param exit boundary node, null will be ignored */ - public void addExit(BlockNode exit) { + public void addExit(@Nullable BlockNode exit) { if (exit != null) { curState.exits.add(exit); } @@ -95,7 +96,7 @@ final class RegionStack { } } - public void removeExit(BlockNode exit) { + public void removeExit(@Nullable BlockNode exit) { if (exit != null) { curState.exits.remove(exit); } diff --git a/jadx-core/src/main/java/jadx/core/dex/visitors/regions/maker/SwitchRegionMaker.java b/jadx-core/src/main/java/jadx/core/dex/visitors/regions/maker/SwitchRegionMaker.java index 6e7bd2e14..800b58f38 100644 --- a/jadx-core/src/main/java/jadx/core/dex/visitors/regions/maker/SwitchRegionMaker.java +++ b/jadx-core/src/main/java/jadx/core/dex/visitors/regions/maker/SwitchRegionMaker.java @@ -19,17 +19,24 @@ import jadx.core.dex.instructions.args.ArgType; import jadx.core.dex.instructions.args.InsnArg; import jadx.core.dex.instructions.args.RegisterArg; import jadx.core.dex.nodes.BlockNode; +import jadx.core.dex.nodes.IContainer; import jadx.core.dex.nodes.IRegion; +import jadx.core.dex.nodes.InsnContainer; import jadx.core.dex.nodes.InsnNode; import jadx.core.dex.nodes.MethodNode; import jadx.core.dex.regions.Region; import jadx.core.dex.regions.SwitchRegion; +import jadx.core.dex.visitors.regions.AbstractRegionVisitor; +import jadx.core.dex.visitors.regions.DepthRegionTraversal; +import jadx.core.dex.visitors.regions.SwitchBreakVisitor; import jadx.core.utils.BlockUtils; +import jadx.core.utils.ListUtils; import jadx.core.utils.RegionUtils; import jadx.core.utils.Utils; +import jadx.core.utils.blocks.BlockSet; import jadx.core.utils.exceptions.JadxRuntimeException; -final class SwitchRegionMaker { +public final class SwitchRegionMaker { private final MethodNode mth; private final RegionMaker regionMaker; @@ -61,21 +68,32 @@ final class SwitchRegionMaker { BlockNode out = calcSwitchOut(block, insn, stack); stack.addExit(out); - processFallThroughCases(sw, out, stack, blocksMap); + addCases(sw, out, stack, blocksMap); removeEmptyCases(insn, sw, defCase); stack.pop(); return out; } - private void processFallThroughCases(SwitchRegion sw, @Nullable BlockNode out, + /** + * Insert 'break' for all cases in switch region + * Executed in {@link jadx.core.dex.visitors.regions.PostProcessRegions} after try/catch wrap to + * handle all blocks + */ + public static void insertBreaks(MethodNode mth, SwitchRegion sw) { + for (SwitchRegion.CaseInfo caseInfo : sw.getCases()) { + insertBreaksForCase(mth, sw, caseInfo.getContainer()); + } + } + + private void addCases(SwitchRegion sw, @Nullable BlockNode out, RegionStack stack, Map> blocksMap) { Map fallThroughCases = new LinkedHashMap<>(); if (out != null) { // detect fallthrough cases BitSet caseBlocks = BlockUtils.blocksToBitSet(mth, blocksMap.keySet()); - caseBlocks.clear(out.getId()); - for (BlockNode successor : sw.getHeader().getCleanSuccessors()) { + caseBlocks.clear(out.getPos()); + for (BlockNode successor : sw.getHeader().getSuccessors()) { BitSet df = successor.getDomFrontier(); if (df.intersects(caseBlocks)) { BlockNode fallThroughBlock = getOneIntersectionBlock(out, caseBlocks, df); @@ -93,31 +111,30 @@ final class SwitchRegionMaker { } } } - for (Map.Entry> entry : blocksMap.entrySet()) { List keysList = entry.getValue(); BlockNode caseBlock = entry.getKey(); + Region caseRegion; if (stack.containsExit(caseBlock)) { - sw.addCase(keysList, new Region(stack.peekRegion())); + caseRegion = new Region(stack.peekRegion()); } else { BlockNode next = fallThroughCases.get(caseBlock); stack.addExit(next); - Region caseRegion = regionMaker.makeRegion(caseBlock); + caseRegion = regionMaker.makeRegion(caseBlock); stack.removeExit(next); if (next != null) { next.add(AFlag.FALL_THROUGH); caseRegion.add(AFlag.FALL_THROUGH); } - sw.addCase(keysList, caseRegion); - // 'break' instruction will be inserted in RegionMakerVisitor.PostRegionVisitor } + sw.addCase(keysList, caseRegion); } } @Nullable private BlockNode getOneIntersectionBlock(BlockNode out, BitSet caseBlocks, BitSet fallThroughSet) { BitSet caseExits = BlockUtils.copyBlocksBitSet(mth, fallThroughSet); - caseExits.clear(out.getId()); + caseExits.clear(out.getPos()); caseExits.and(caseBlocks); return BlockUtils.bitSetToOneBlock(mth, caseExits); } @@ -341,4 +358,49 @@ final class SwitchRegionMaker { } return inserted; } + + /** + * Add break to every exit edge from 'case' region. + * 'Break' optimizations (code duplication, unreachable, etc.) will be done at + * {@link SwitchBreakVisitor} + */ + private static void insertBreaksForCase(MethodNode mth, SwitchRegion switchRegion, IContainer caseContainer) { + BlockSet caseBlocks = new BlockSet(mth); + RegionUtils.visitBlockNodes(mth, caseContainer, caseBlocks::add); + DepthRegionTraversal.traverse(mth, caseContainer, new AbstractRegionVisitor() { + @Override + public void leaveRegion(MethodNode mth, IRegion region) { + boolean insertBreak = false; + if (region == caseContainer) { + // top region + insertBreak = true; + } else { + IContainer lastContainer = ListUtils.last(region.getSubBlocks()); + if (lastContainer instanceof BlockNode) { + BlockNode lastBlock = (BlockNode) lastContainer; + for (BlockNode successor : lastBlock.getSuccessors()) { + if (!caseBlocks.contains(successor)) { + insertBreak = true; + break; + } + } + } + } + if (insertBreak && canAppendBreak(region)) { + region.getSubBlocks().add(buildBreakContainer(switchRegion)); + } + } + }); + } + + public static boolean canAppendBreak(IRegion region) { + return !region.contains(AFlag.FALL_THROUGH) && !RegionUtils.hasExitBlock(region); + } + + public static InsnContainer buildBreakContainer(SwitchRegion switchRegion) { + InsnNode breakInsn = new InsnNode(InsnType.BREAK, 0); + breakInsn.add(AFlag.SYNTHETIC); + breakInsn.addAttr(new RegionRefAttr(switchRegion)); + return new InsnContainer(breakInsn); + } } diff --git a/jadx-core/src/main/java/jadx/core/utils/ListUtils.java b/jadx-core/src/main/java/jadx/core/utils/ListUtils.java index ff32a5cf4..ddc82704b 100644 --- a/jadx-core/src/main/java/jadx/core/utils/ListUtils.java +++ b/jadx-core/src/main/java/jadx/core/utils/ListUtils.java @@ -67,7 +67,10 @@ public class ListUtils { return list.get(0); } - public static T last(List list) { + public static @Nullable T last(List list) { + if (list == null || list.isEmpty()) { + return null; + } return list.get(list.size() - 1); } diff --git a/jadx-core/src/main/java/jadx/core/utils/RegionUtils.java b/jadx-core/src/main/java/jadx/core/utils/RegionUtils.java index 5ed98d14d..18e5ea30f 100644 --- a/jadx-core/src/main/java/jadx/core/utils/RegionUtils.java +++ b/jadx-core/src/main/java/jadx/core/utils/RegionUtils.java @@ -51,9 +51,8 @@ public class RegionUtils { return true; } if (container instanceof IRegion) { - IRegion region = (IRegion) container; - List blocks = region.getSubBlocks(); - return !blocks.isEmpty() && hasExitEdge(blocks.get(blocks.size() - 1)); + IContainer last = Utils.last(((IRegion) container).getSubBlocks()); + return last != null && hasExitEdge(last); } throw new JadxRuntimeException(unknownContainerType(container)); } @@ -81,6 +80,26 @@ public class RegionUtils { } } + public static @Nullable BlockNode getFirstBlockNode(IContainer container) { + if (container instanceof IBlock) { + if (container instanceof BlockNode) { + return (BlockNode) container; + } + return null; + } + if (container instanceof IBranchRegion) { + return null; + } + if (container instanceof IRegion) { + List blocks = ((IRegion) container).getSubBlocks(); + if (blocks.isEmpty()) { + return null; + } + return getFirstBlockNode(blocks.get(0)); + } + throw new JadxRuntimeException(unknownContainerType(container)); + } + public static int getFirstSourceLine(IContainer container) { if (container instanceof IBlock) { return BlockUtils.getFirstSourceLine((IBlock) container); @@ -517,6 +536,17 @@ public class RegionUtils { }); } + public static void visitBlockNodes(MethodNode mth, IContainer container, Consumer visitor) { + DepthRegionTraversal.traverse(mth, container, new AbstractRegionVisitor() { + @Override + public void processBlock(MethodNode mth, IBlock block) { + if (block instanceof BlockNode) { + visitor.accept((BlockNode) block); + } + } + }); + } + public static void visitRegions(MethodNode mth, IContainer container, Predicate visitor) { DepthRegionTraversal.traverse(mth, container, new AbstractRegionVisitor() { @Override diff --git a/jadx-core/src/test/java/jadx/tests/integration/switches/TestSwitch2.java b/jadx-core/src/test/java/jadx/tests/integration/switches/TestSwitch2.java index 23ac0d6f2..ed929c939 100644 --- a/jadx-core/src/test/java/jadx/tests/integration/switches/TestSwitch2.java +++ b/jadx-core/src/test/java/jadx/tests/integration/switches/TestSwitch2.java @@ -57,10 +57,9 @@ public class TestSwitch2 extends IntegrationTest { public void test() { assertThat(getClassNode(TestCls.class)) .code() - // .countString(4, "break;" + .countString(4, "break;") // .countString(2, "return;") - // TODO: remove redundant reak and returns - .countString(5, "break;") + // TODO: remove redundant returns .countString(4, "return;"); } } diff --git a/jadx-core/src/test/java/jadx/tests/integration/switches/TestSwitchBreak2.java b/jadx-core/src/test/java/jadx/tests/integration/switches/TestSwitchBreak2.java new file mode 100644 index 000000000..d43d7f00d --- /dev/null +++ b/jadx-core/src/test/java/jadx/tests/integration/switches/TestSwitchBreak2.java @@ -0,0 +1,53 @@ +package jadx.tests.integration.switches; + +import jadx.tests.api.IntegrationTest; +import jadx.tests.api.extensions.profiles.TestProfile; +import jadx.tests.api.extensions.profiles.TestWithProfiles; + +import static jadx.tests.api.utils.assertj.JadxAssertions.assertThat; + +public class TestSwitchBreak2 extends IntegrationTest { + + @SuppressWarnings("SwitchStatementWithTooFewBranches") + public static class TestCls { + private int value; + + public void test(int i, boolean b1, boolean b2) { + setValue(-1); + switch (i) { + case 0: + if (b1 && b2) { + setValue(1); + // no break here; + } else { + setValue(2); + // no break here; + } + break; + default: + setValue(0); + break; + } + } + + private void setValue(int value) { + this.value = value; + } + + public void check() { + test(0, true, true); + assertThat(value).isEqualTo(1); + test(0, true, false); + assertThat(value).isEqualTo(2); + test(1, true, true); + assertThat(value).isEqualTo(0); + } + } + + @TestWithProfiles({ TestProfile.JAVA11, TestProfile.D8_J11 }) + public void test() { + assertThat(getClassNode(TestCls.class)) + .code() + .countString(2, "break;"); + } +} diff --git a/jadx-core/src/test/java/jadx/tests/integration/switches/TestSwitchWithTryCatch.java b/jadx-core/src/test/java/jadx/tests/integration/switches/TestSwitchWithTryCatch.java index 308d150d2..41ed7505f 100644 --- a/jadx-core/src/test/java/jadx/tests/integration/switches/TestSwitchWithTryCatch.java +++ b/jadx-core/src/test/java/jadx/tests/integration/switches/TestSwitchWithTryCatch.java @@ -60,9 +60,7 @@ public class TestSwitchWithTryCatch extends IntegrationTest { public void test() { assertThat(getClassNode(TestCls.class)) .code() - // .countString(3, "break;") - .countString(4, "return;") - // TODO: remove redundant break - .countString(4, "break;"); + .countString(3, "break;") + .countString(4, "return;"); } }