fix: improve common 'break' extract checks (#2697)

This commit is contained in:
Skylot
2025-11-20 19:21:58 +00:00
parent 6aeaf6aca9
commit 9bf079aad4
7 changed files with 163 additions and 32 deletions
@@ -30,7 +30,17 @@ public class IfRegionVisitor extends AbstractVisitor {
process(mth);
}
public static void process(MethodNode mth) {
public static void processIfRequested(MethodNode mth) {
if (mth.contains(AFlag.REQUEST_IF_REGION_OPTIMIZE)) {
try {
process(mth);
} finally {
mth.remove(AFlag.REQUEST_IF_REGION_OPTIMIZE);
}
}
}
private static void process(MethodNode mth) {
TernaryMod.process(mth);
DepthRegionTraversal.traverse(mth, PROCESS_IF_REGION_VISITOR);
DepthRegionTraversal.traverseIterative(mth, REMOVE_REDUNDANT_ELSE_VISITOR);
@@ -48,7 +58,7 @@ public class IfRegionVisitor extends AbstractVisitor {
}
}
@SuppressWarnings({ "UnnecessaryReturnStatement", "StatementWithEmptyBody" })
@SuppressWarnings({ "UnnecessaryReturnStatement" })
private static void orderBranches(MethodNode mth, IfRegion ifRegion) {
if (RegionUtils.isEmpty(ifRegion.getElseRegion())) {
return;
@@ -158,7 +168,7 @@ public class IfRegionVisitor extends AbstractVisitor {
}
}
@SuppressWarnings("StatementWithEmptyBody")
@SuppressWarnings("UnnecessaryParentheses")
private static boolean removeRedundantElseBlock(MethodNode mth, IfRegion ifRegion) {
if (ifRegion.getElseRegion() == null) {
return false;
@@ -53,10 +53,7 @@ public class LoopRegionVisitor extends AbstractVisitor implements IRegionVisitor
@Override
public void visit(MethodNode mth) {
DepthRegionTraversal.traverse(mth, this);
if (mth.contains(AFlag.REQUEST_IF_REGION_OPTIMIZE)) {
IfRegionVisitor.process(mth);
mth.remove(AFlag.REQUEST_IF_REGION_OPTIMIZE);
}
IfRegionVisitor.processIfRequested(mth);
}
@Override
@@ -1,5 +1,6 @@
package jadx.core.dex.visitors.regions;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
@@ -22,6 +23,8 @@ import jadx.core.dex.regions.SwitchRegion;
import jadx.core.dex.visitors.AbstractVisitor;
import jadx.core.dex.visitors.JadxVisitor;
import jadx.core.dex.visitors.regions.maker.SwitchRegionMaker;
import jadx.core.utils.BlockInsnPair;
import jadx.core.utils.BlockParentContainer;
import jadx.core.utils.BlockUtils;
import jadx.core.utils.ListUtils;
import jadx.core.utils.RegionUtils;
@@ -41,9 +44,14 @@ public class SwitchBreakVisitor extends AbstractVisitor {
if (CodeFeaturesAttr.contains(mth, SWITCH)) {
DepthRegionTraversal.traverse(mth, new ExtractCommonBreak());
DepthRegionTraversal.traverse(mth, new RemoveUnreachableBreak());
IfRegionVisitor.processIfRequested(mth);
}
}
/**
* Add common 'break' if 'break' or exit insn ('return', 'throw', 'continue') found in all branches.
* Remove exist common break if all branches contain exit insn.
*/
private static final class ExtractCommonBreak extends BaseSwitchRegionVisitor {
@Override
public boolean switchVisitCondition(MethodNode mth, SwitchRegion switchRegion) {
@@ -54,11 +62,11 @@ public class SwitchBreakVisitor extends AbstractVisitor {
public void processRegion(MethodNode mth, IRegion region) {
if (region instanceof IBranchRegion) {
// if break in all branches extract to parent region
processBranchRegion(region);
processBranchRegion(mth, region);
}
}
private void processBranchRegion(IRegion region) {
private void processBranchRegion(MethodNode mth, IRegion region) {
IRegion parentRegion = region.getParent();
if (parentRegion.contains(AFlag.FALL_THROUGH)) {
// fallthrough case, can't extract break
@@ -76,40 +84,37 @@ public class SwitchBreakVisitor extends AbstractVisitor {
}
}
List<IContainer> branches = ((IBranchRegion) region).getBranches();
boolean removeBranchBreaks = false;
boolean removeCommonBreak = true; // all branches contain exit insns, common break is unreachable
List<BlockParentContainer> forBreakRemove = new ArrayList<>();
for (IContainer branch : branches) {
if (branch == null) {
removeCommonBreak = false;
continue;
}
IBlock lastBlock = RegionUtils.getLastBlock(branch);
InsnNode lastInsn = BlockUtils.getLastInsn(lastBlock);
if (lastInsn == null) {
BlockInsnPair last = RegionUtils.getLastInsnWithBlock(branch);
if (last == null) {
return;
}
InsnNode lastInsn = last.getInsn();
if (lastInsn.getType() == InsnType.BREAK) {
removeBranchBreaks = true;
IBlock block = last.getBlock();
IContainer parent = RegionUtils.getBlockContainer(branch, block);
forBreakRemove.add(new BlockParentContainer(parent, block));
removeCommonBreak = false;
} else if (!lastInsn.isExitEdgeInsn()) {
removeCommonBreak = false;
}
}
if (removeBranchBreaks) {
if (!forBreakRemove.isEmpty()) {
// common 'break' confirmed
for (IContainer branch : branches) {
if (branch == null) {
continue;
}
// remove breaks from all branches
IBlock lastBlock = RegionUtils.getLastBlock(branch);
if (lastBlock != null) {
removeBreak(lastBlock, branch);
}
for (BlockParentContainer breakData : forBreakRemove) {
removeBreak(breakData.getBlock(), breakData.getParent());
}
if (!dontAddCommonBreak) {
addBreakRegion.add(parentRegion);
}
// removed 'break' may allow to use 'else-if' chain
mth.add(AFlag.REQUEST_IF_REGION_OPTIMIZE);
}
if (removeCommonBreak && lastParentBlock != null) {
removeBreak(lastParentBlock, parentRegion);
@@ -2,19 +2,19 @@ package jadx.core.utils;
import java.util.Objects;
import jadx.core.dex.nodes.BlockNode;
import jadx.core.dex.nodes.IBlock;
import jadx.core.dex.nodes.InsnNode;
public class BlockInsnPair {
private final BlockNode block;
private final IBlock block;
private final InsnNode insn;
public BlockInsnPair(BlockNode block, InsnNode insn) {
public BlockInsnPair(IBlock block, InsnNode insn) {
this.block = block;
this.insn = insn;
}
public BlockNode getBlock() {
public IBlock getBlock() {
return block;
}
@@ -0,0 +1,30 @@
package jadx.core.utils;
import java.util.Objects;
import jadx.core.dex.nodes.IBlock;
import jadx.core.dex.nodes.IContainer;
public class BlockParentContainer {
private final IContainer parent;
private final IBlock block;
public BlockParentContainer(IContainer parent, IBlock block) {
this.parent = Objects.requireNonNull(parent);
this.block = Objects.requireNonNull(block);
}
public IBlock getBlock() {
return block;
}
public IContainer getParent() {
return parent;
}
@Override
public String toString() {
return "BlockParentContainer{" + block + ", parent=" + parent + '}';
}
}
@@ -3,6 +3,7 @@ package jadx.core.utils;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.function.Consumer;
import java.util.function.Predicate;
@@ -153,6 +154,40 @@ public class RegionUtils {
}
}
public static BlockInsnPair getLastInsnWithBlock(IContainer container) {
if (container instanceof IBlock) {
IBlock block = (IBlock) container;
InsnNode lastInsn = ListUtils.last(block.getInstructions());
if (lastInsn == null) {
return null;
}
return new BlockInsnPair(block, lastInsn);
}
if (container instanceof IBranchRegion) {
List<IContainer> branches = ((IBranchRegion) container).getBranches();
long count = branches.stream().filter(Objects::nonNull).count();
if (count == 1) {
// single branch
for (IContainer branch : branches) {
if (branch != null) {
return getLastInsnWithBlock(branch);
}
}
}
// several last instructions
return null;
}
if (container instanceof IRegion) {
IRegion region = (IRegion) container;
List<IContainer> blocks = region.getSubBlocks();
if (blocks.isEmpty()) {
return null;
}
return getLastInsnWithBlock(ListUtils.last(blocks));
}
throw new JadxRuntimeException(unknownContainerType(container));
}
public static IBlock getLastBlock(IContainer container) {
if (container instanceof IBlock) {
return (IBlock) container;
@@ -439,10 +474,11 @@ public class RegionUtils {
return true;
}
public static IContainer getBlockContainer(IContainer container, BlockNode block) {
public static IContainer getBlockContainer(IContainer container, IBlock block) {
if (container instanceof IBlock) {
return container == block ? container : null;
} else if (container instanceof IRegion) {
}
if (container instanceof IRegion) {
IRegion region = (IRegion) container;
for (IContainer c : region.getSubBlocks()) {
IContainer res = getBlockContainer(c, block);
@@ -451,9 +487,8 @@ public class RegionUtils {
}
}
return null;
} else {
throw new JadxRuntimeException(unknownContainerType(container));
}
throw new JadxRuntimeException(unknownContainerType(container));
}
/**
@@ -0,0 +1,54 @@
package jadx.tests.integration.switches;
import jadx.tests.api.IntegrationTest;
import jadx.tests.api.extensions.profiles.TestProfile;
import jadx.tests.api.extensions.profiles.TestWithProfiles;
import static jadx.tests.api.utils.assertj.JadxAssertions.assertThat;
public class TestSwitchBreak3 extends IntegrationTest {
@SuppressWarnings("SwitchStatementWithTooFewBranches")
public static class TestCls {
private int value;
public void test(int i, boolean b1, boolean b2, boolean b3) {
setValue(-1);
switch (i) {
case 0:
if (b1 == b2) {
setValue(1);
// no break here;
} else if (b1 == b3) {
setValue(2);
// no break here;
}
break;
default:
setValue(0);
break;
}
}
private void setValue(int value) {
this.value = value;
}
public void check() {
test(0, true, true, true);
assertThat(value).isEqualTo(1);
test(0, true, false, true);
assertThat(value).isEqualTo(2);
test(1, true, true, true);
assertThat(value).isEqualTo(0);
}
}
@TestWithProfiles({ TestProfile.JAVA11, TestProfile.D8_J11 })
public void test() {
assertThat(getClassNode(TestCls.class))
.code()
.countString(2, "break;")
.containsOne("} else if (");
}
}