From cf0101f13dc216a9e88d4ede894ee2cbadeb407d Mon Sep 17 00:00:00 2001 From: Skylot <118523+skylot@users.noreply.github.com> Date: Sat, 22 Nov 2025 21:58:34 +0000 Subject: [PATCH] fix: support 'break' extract for nested 'if' (#2697) --- .../visitors/regions/SwitchBreakVisitor.java | 105 ++++++++++++------ .../switches/TestSwitchBreak4.java | 55 +++++++++ 2 files changed, 123 insertions(+), 37 deletions(-) create mode 100644 jadx-core/src/test/java/jadx/tests/integration/switches/TestSwitchBreak4.java diff --git a/jadx-core/src/main/java/jadx/core/dex/visitors/regions/SwitchBreakVisitor.java b/jadx-core/src/main/java/jadx/core/dex/visitors/regions/SwitchBreakVisitor.java index 47c9ce9a5..69236447c 100644 --- a/jadx-core/src/main/java/jadx/core/dex/visitors/regions/SwitchBreakVisitor.java +++ b/jadx-core/src/main/java/jadx/core/dex/visitors/regions/SwitchBreakVisitor.java @@ -4,7 +4,7 @@ import java.util.ArrayList; import java.util.HashSet; import java.util.List; import java.util.Set; -import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Supplier; import org.jetbrains.annotations.Nullable; @@ -42,25 +42,24 @@ public class SwitchBreakVisitor extends AbstractVisitor { @Override public void visit(MethodNode mth) throws JadxException { if (CodeFeaturesAttr.contains(mth, SWITCH)) { - DepthRegionTraversal.traverse(mth, new ExtractCommonBreak()); - DepthRegionTraversal.traverse(mth, new RemoveUnreachableBreak()); + runSwitchTraverse(mth, ExtractCommonBreak::new); + runSwitchTraverse(mth, RemoveUnreachableBreak::new); IfRegionVisitor.processIfRequested(mth); } } + private static void runSwitchTraverse(MethodNode mth, Supplier builder) { + DepthRegionTraversal.traverse(mth, new IterativeSwitchRegionVisitor(builder)); + } + /** * Add common 'break' if 'break' or exit insn ('return', 'throw', 'continue') found in all branches. * Remove exist common break if all branches contain exit insn. */ private static final class ExtractCommonBreak extends BaseSwitchRegionVisitor { - @Override - public boolean switchVisitCondition(MethodNode mth, SwitchRegion switchRegion) { - return countBreaks(mth, switchRegion) > 1; - } - @Override public void processRegion(MethodNode mth, IRegion region) { - if (region instanceof IBranchRegion) { + if (region instanceof IBranchRegion && !(region instanceof SwitchRegion)) { // if break in all branches extract to parent region processBranchRegion(mth, region); } @@ -112,6 +111,8 @@ public class SwitchBreakVisitor extends AbstractVisitor { } if (!dontAddCommonBreak) { addBreakRegion.add(parentRegion); + // new 'break' might become 'common' for upper branch region, request to run checks again + requestReRun(); } // removed 'break' may allow to use 'else-if' chain mth.add(AFlag.REQUEST_IF_REGION_OPTIMIZE); @@ -120,16 +121,6 @@ public class SwitchBreakVisitor extends AbstractVisitor { removeBreak(lastParentBlock, parentRegion); } } - - private int countBreaks(MethodNode mth, IRegion region) { - AtomicInteger count = new AtomicInteger(0); - RegionUtils.visitBlocks(mth, region, block -> { - if (isBreakBlock(block)) { - count.incrementAndGet(); - } - }); - return count.get(); - } } private static final class RemoveUnreachableBreak extends BaseSwitchRegionVisitor { @@ -162,39 +153,58 @@ public class SwitchBreakVisitor extends AbstractVisitor { } } + /** + * For every 'switch' region run new instance of provided 'switch' visitor. + * If rerun requested, run traverse for that visitor again. + */ + private static final class IterativeSwitchRegionVisitor extends AbstractRegionVisitor { + private final Supplier builder; + + public IterativeSwitchRegionVisitor(Supplier builder) { + this.builder = builder; + } + + @Override + public void leaveRegion(MethodNode mth, IRegion region) { + if (region instanceof SwitchRegion) { + SwitchRegion switchRegion = (SwitchRegion) region; + BaseSwitchRegionVisitor switchVisitor = builder.get(); + switchVisitor.setCurrentSwitch(switchRegion); + boolean runAgain; + int k = 0; + do { + runAgain = false; + DepthRegionTraversal.traverse(mth, switchRegion, switchVisitor); + if (switchVisitor.isReRunRequested()) { + switchVisitor.reset(); + runAgain = true; + } + if (k++ > 20) { + // 20 nested 'if' are not expected + mth.addWarnComment("Unexpected iteration count in SwitchBreakVisitor. Please report as an issue"); + break; + } + } while (runAgain); + } + } + } + private abstract static class BaseSwitchRegionVisitor extends AbstractRegionVisitor { protected final Set addBreakRegion = new HashSet<>(); protected final Set cleanupSet = new HashSet<>(); protected SwitchRegion currentSwitch; + private boolean reRunRequested = false; public abstract void processRegion(MethodNode mth, IRegion region); - public boolean switchVisitCondition(MethodNode mth, SwitchRegion switchRegion) { - return true; - } - @Override public boolean enterRegion(MethodNode mth, IRegion region) { - if (region instanceof SwitchRegion) { - SwitchRegion switchRegion = (SwitchRegion) region; - this.currentSwitch = switchRegion; - return switchVisitCondition(mth, switchRegion); - } - if (currentSwitch == null) { - return true; - } processRegion(mth, region); return true; } @Override public void leaveRegion(MethodNode mth, IRegion region) { - if (region == currentSwitch) { - currentSwitch = null; - addBreakRegion.clear(); - cleanupSet.clear(); - return; - } if (addBreakRegion.contains(region)) { addBreakRegion.remove(region); region.getSubBlocks().add(SwitchRegionMaker.buildBreakContainer(currentSwitch)); @@ -205,6 +215,27 @@ public class SwitchBreakVisitor extends AbstractVisitor { } } + /** + * Method called before visitor rerun + */ + public void reset() { + reRunRequested = false; + addBreakRegion.clear(); + cleanupSet.clear(); + } + + public void requestReRun() { + reRunRequested = true; + } + + public boolean isReRunRequested() { + return reRunRequested; + } + + public void setCurrentSwitch(SwitchRegion currentSwitch) { + this.currentSwitch = currentSwitch; + } + protected boolean isBreakBlock(@Nullable IBlock block) { if (block != null) { InsnNode lastInsn = ListUtils.last(block.getInstructions()); diff --git a/jadx-core/src/test/java/jadx/tests/integration/switches/TestSwitchBreak4.java b/jadx-core/src/test/java/jadx/tests/integration/switches/TestSwitchBreak4.java new file mode 100644 index 000000000..e572fd859 --- /dev/null +++ b/jadx-core/src/test/java/jadx/tests/integration/switches/TestSwitchBreak4.java @@ -0,0 +1,55 @@ +package jadx.tests.integration.switches; + +import jadx.tests.api.IntegrationTest; +import jadx.tests.api.extensions.profiles.TestProfile; +import jadx.tests.api.extensions.profiles.TestWithProfiles; + +import static jadx.tests.api.utils.assertj.JadxAssertions.assertThat; + +public class TestSwitchBreak4 extends IntegrationTest { + + @SuppressWarnings("SwitchStatementWithTooFewBranches") + public static class TestCls { + private int value; + + public void test(int i, boolean b1, boolean b2, boolean b3) { + setValue(-1); + switch (i) { + case 0: + if (b1 == b2) { + setValue(1); + } else if (b1 == b3) { + setValue(2); + } else { + setValue(3); + } + break; + default: + setValue(0); + break; + } + } + + private void setValue(int value) { + this.value = value; + } + + public void check() { + test(0, true, true, true); + assertThat(value).isEqualTo(1); + test(0, true, false, true); + assertThat(value).isEqualTo(2); + test(0, true, false, false); + assertThat(value).isEqualTo(3); + } + } + + @TestWithProfiles({ TestProfile.JAVA11, TestProfile.D8_J11 }) + public void test() { + assertThat(getClassNode(TestCls.class)) + .code() + .countString(2, "break;") + .containsOne("} else if (") + .containsOne("} else {"); + } +}