fix: support 'break' extract for nested 'if' (#2697)

This commit is contained in:
Skylot
2025-11-22 21:58:34 +00:00
parent 748f45b386
commit cf0101f13d
2 changed files with 123 additions and 37 deletions
@@ -4,7 +4,7 @@ import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Supplier;
import org.jetbrains.annotations.Nullable;
@@ -42,25 +42,24 @@ public class SwitchBreakVisitor extends AbstractVisitor {
@Override
public void visit(MethodNode mth) throws JadxException {
if (CodeFeaturesAttr.contains(mth, SWITCH)) {
DepthRegionTraversal.traverse(mth, new ExtractCommonBreak());
DepthRegionTraversal.traverse(mth, new RemoveUnreachableBreak());
runSwitchTraverse(mth, ExtractCommonBreak::new);
runSwitchTraverse(mth, RemoveUnreachableBreak::new);
IfRegionVisitor.processIfRequested(mth);
}
}
private static void runSwitchTraverse(MethodNode mth, Supplier<BaseSwitchRegionVisitor> builder) {
DepthRegionTraversal.traverse(mth, new IterativeSwitchRegionVisitor(builder));
}
/**
* Add common 'break' if 'break' or exit insn ('return', 'throw', 'continue') found in all branches.
* Remove exist common break if all branches contain exit insn.
*/
private static final class ExtractCommonBreak extends BaseSwitchRegionVisitor {
@Override
public boolean switchVisitCondition(MethodNode mth, SwitchRegion switchRegion) {
return countBreaks(mth, switchRegion) > 1;
}
@Override
public void processRegion(MethodNode mth, IRegion region) {
if (region instanceof IBranchRegion) {
if (region instanceof IBranchRegion && !(region instanceof SwitchRegion)) {
// if break in all branches extract to parent region
processBranchRegion(mth, region);
}
@@ -112,6 +111,8 @@ public class SwitchBreakVisitor extends AbstractVisitor {
}
if (!dontAddCommonBreak) {
addBreakRegion.add(parentRegion);
// new 'break' might become 'common' for upper branch region, request to run checks again
requestReRun();
}
// removed 'break' may allow to use 'else-if' chain
mth.add(AFlag.REQUEST_IF_REGION_OPTIMIZE);
@@ -120,16 +121,6 @@ public class SwitchBreakVisitor extends AbstractVisitor {
removeBreak(lastParentBlock, parentRegion);
}
}
private int countBreaks(MethodNode mth, IRegion region) {
AtomicInteger count = new AtomicInteger(0);
RegionUtils.visitBlocks(mth, region, block -> {
if (isBreakBlock(block)) {
count.incrementAndGet();
}
});
return count.get();
}
}
private static final class RemoveUnreachableBreak extends BaseSwitchRegionVisitor {
@@ -162,39 +153,58 @@ public class SwitchBreakVisitor extends AbstractVisitor {
}
}
/**
* For every 'switch' region run new instance of provided 'switch' visitor.
* If rerun requested, run traverse for that visitor again.
*/
private static final class IterativeSwitchRegionVisitor extends AbstractRegionVisitor {
private final Supplier<BaseSwitchRegionVisitor> builder;
public IterativeSwitchRegionVisitor(Supplier<BaseSwitchRegionVisitor> builder) {
this.builder = builder;
}
@Override
public void leaveRegion(MethodNode mth, IRegion region) {
if (region instanceof SwitchRegion) {
SwitchRegion switchRegion = (SwitchRegion) region;
BaseSwitchRegionVisitor switchVisitor = builder.get();
switchVisitor.setCurrentSwitch(switchRegion);
boolean runAgain;
int k = 0;
do {
runAgain = false;
DepthRegionTraversal.traverse(mth, switchRegion, switchVisitor);
if (switchVisitor.isReRunRequested()) {
switchVisitor.reset();
runAgain = true;
}
if (k++ > 20) {
// 20 nested 'if' are not expected
mth.addWarnComment("Unexpected iteration count in SwitchBreakVisitor. Please report as an issue");
break;
}
} while (runAgain);
}
}
}
private abstract static class BaseSwitchRegionVisitor extends AbstractRegionVisitor {
protected final Set<IRegion> addBreakRegion = new HashSet<>();
protected final Set<IContainer> cleanupSet = new HashSet<>();
protected SwitchRegion currentSwitch;
private boolean reRunRequested = false;
public abstract void processRegion(MethodNode mth, IRegion region);
public boolean switchVisitCondition(MethodNode mth, SwitchRegion switchRegion) {
return true;
}
@Override
public boolean enterRegion(MethodNode mth, IRegion region) {
if (region instanceof SwitchRegion) {
SwitchRegion switchRegion = (SwitchRegion) region;
this.currentSwitch = switchRegion;
return switchVisitCondition(mth, switchRegion);
}
if (currentSwitch == null) {
return true;
}
processRegion(mth, region);
return true;
}
@Override
public void leaveRegion(MethodNode mth, IRegion region) {
if (region == currentSwitch) {
currentSwitch = null;
addBreakRegion.clear();
cleanupSet.clear();
return;
}
if (addBreakRegion.contains(region)) {
addBreakRegion.remove(region);
region.getSubBlocks().add(SwitchRegionMaker.buildBreakContainer(currentSwitch));
@@ -205,6 +215,27 @@ public class SwitchBreakVisitor extends AbstractVisitor {
}
}
/**
* Method called before visitor rerun
*/
public void reset() {
reRunRequested = false;
addBreakRegion.clear();
cleanupSet.clear();
}
public void requestReRun() {
reRunRequested = true;
}
public boolean isReRunRequested() {
return reRunRequested;
}
public void setCurrentSwitch(SwitchRegion currentSwitch) {
this.currentSwitch = currentSwitch;
}
protected boolean isBreakBlock(@Nullable IBlock block) {
if (block != null) {
InsnNode lastInsn = ListUtils.last(block.getInstructions());
@@ -0,0 +1,55 @@
package jadx.tests.integration.switches;
import jadx.tests.api.IntegrationTest;
import jadx.tests.api.extensions.profiles.TestProfile;
import jadx.tests.api.extensions.profiles.TestWithProfiles;
import static jadx.tests.api.utils.assertj.JadxAssertions.assertThat;
public class TestSwitchBreak4 extends IntegrationTest {
@SuppressWarnings("SwitchStatementWithTooFewBranches")
public static class TestCls {
private int value;
public void test(int i, boolean b1, boolean b2, boolean b3) {
setValue(-1);
switch (i) {
case 0:
if (b1 == b2) {
setValue(1);
} else if (b1 == b3) {
setValue(2);
} else {
setValue(3);
}
break;
default:
setValue(0);
break;
}
}
private void setValue(int value) {
this.value = value;
}
public void check() {
test(0, true, true, true);
assertThat(value).isEqualTo(1);
test(0, true, false, true);
assertThat(value).isEqualTo(2);
test(0, true, false, false);
assertThat(value).isEqualTo(3);
}
}
@TestWithProfiles({ TestProfile.JAVA11, TestProfile.D8_J11 })
public void test() {
assertThat(getClassNode(TestCls.class))
.code()
.countString(2, "break;")
.containsOne("} else if (")
.containsOne("} else {");
}
}