diff --git a/jadx-core/src/main/java/jadx/core/dex/visitors/regions/IfRegionVisitor.java b/jadx-core/src/main/java/jadx/core/dex/visitors/regions/IfRegionVisitor.java index 711cad11a..65c1d75ee 100644 --- a/jadx-core/src/main/java/jadx/core/dex/visitors/regions/IfRegionVisitor.java +++ b/jadx-core/src/main/java/jadx/core/dex/visitors/regions/IfRegionVisitor.java @@ -30,7 +30,17 @@ public class IfRegionVisitor extends AbstractVisitor { process(mth); } - public static void process(MethodNode mth) { + public static void processIfRequested(MethodNode mth) { + if (mth.contains(AFlag.REQUEST_IF_REGION_OPTIMIZE)) { + try { + process(mth); + } finally { + mth.remove(AFlag.REQUEST_IF_REGION_OPTIMIZE); + } + } + } + + private static void process(MethodNode mth) { TernaryMod.process(mth); DepthRegionTraversal.traverse(mth, PROCESS_IF_REGION_VISITOR); DepthRegionTraversal.traverseIterative(mth, REMOVE_REDUNDANT_ELSE_VISITOR); @@ -48,7 +58,7 @@ public class IfRegionVisitor extends AbstractVisitor { } } - @SuppressWarnings({ "UnnecessaryReturnStatement", "StatementWithEmptyBody" }) + @SuppressWarnings({ "UnnecessaryReturnStatement" }) private static void orderBranches(MethodNode mth, IfRegion ifRegion) { if (RegionUtils.isEmpty(ifRegion.getElseRegion())) { return; @@ -158,7 +168,7 @@ public class IfRegionVisitor extends AbstractVisitor { } } - @SuppressWarnings("StatementWithEmptyBody") + @SuppressWarnings("UnnecessaryParentheses") private static boolean removeRedundantElseBlock(MethodNode mth, IfRegion ifRegion) { if (ifRegion.getElseRegion() == null) { return false; diff --git a/jadx-core/src/main/java/jadx/core/dex/visitors/regions/LoopRegionVisitor.java b/jadx-core/src/main/java/jadx/core/dex/visitors/regions/LoopRegionVisitor.java index 09d5608a7..95120e707 100644 --- a/jadx-core/src/main/java/jadx/core/dex/visitors/regions/LoopRegionVisitor.java +++ b/jadx-core/src/main/java/jadx/core/dex/visitors/regions/LoopRegionVisitor.java @@ -53,10 +53,7 @@ public class LoopRegionVisitor extends AbstractVisitor implements IRegionVisitor @Override public void visit(MethodNode mth) { DepthRegionTraversal.traverse(mth, this); - if (mth.contains(AFlag.REQUEST_IF_REGION_OPTIMIZE)) { - IfRegionVisitor.process(mth); - mth.remove(AFlag.REQUEST_IF_REGION_OPTIMIZE); - } + IfRegionVisitor.processIfRequested(mth); } @Override diff --git a/jadx-core/src/main/java/jadx/core/dex/visitors/regions/SwitchBreakVisitor.java b/jadx-core/src/main/java/jadx/core/dex/visitors/regions/SwitchBreakVisitor.java index 6b49eda4b..47c9ce9a5 100644 --- a/jadx-core/src/main/java/jadx/core/dex/visitors/regions/SwitchBreakVisitor.java +++ b/jadx-core/src/main/java/jadx/core/dex/visitors/regions/SwitchBreakVisitor.java @@ -1,5 +1,6 @@ package jadx.core.dex.visitors.regions; +import java.util.ArrayList; import java.util.HashSet; import java.util.List; import java.util.Set; @@ -22,6 +23,8 @@ import jadx.core.dex.regions.SwitchRegion; import jadx.core.dex.visitors.AbstractVisitor; import jadx.core.dex.visitors.JadxVisitor; import jadx.core.dex.visitors.regions.maker.SwitchRegionMaker; +import jadx.core.utils.BlockInsnPair; +import jadx.core.utils.BlockParentContainer; import jadx.core.utils.BlockUtils; import jadx.core.utils.ListUtils; import jadx.core.utils.RegionUtils; @@ -41,9 +44,14 @@ public class SwitchBreakVisitor extends AbstractVisitor { if (CodeFeaturesAttr.contains(mth, SWITCH)) { DepthRegionTraversal.traverse(mth, new ExtractCommonBreak()); DepthRegionTraversal.traverse(mth, new RemoveUnreachableBreak()); + IfRegionVisitor.processIfRequested(mth); } } + /** + * Add common 'break' if 'break' or exit insn ('return', 'throw', 'continue') found in all branches. + * Remove exist common break if all branches contain exit insn. + */ private static final class ExtractCommonBreak extends BaseSwitchRegionVisitor { @Override public boolean switchVisitCondition(MethodNode mth, SwitchRegion switchRegion) { @@ -54,11 +62,11 @@ public class SwitchBreakVisitor extends AbstractVisitor { public void processRegion(MethodNode mth, IRegion region) { if (region instanceof IBranchRegion) { // if break in all branches extract to parent region - processBranchRegion(region); + processBranchRegion(mth, region); } } - private void processBranchRegion(IRegion region) { + private void processBranchRegion(MethodNode mth, IRegion region) { IRegion parentRegion = region.getParent(); if (parentRegion.contains(AFlag.FALL_THROUGH)) { // fallthrough case, can't extract break @@ -76,40 +84,37 @@ public class SwitchBreakVisitor extends AbstractVisitor { } } List branches = ((IBranchRegion) region).getBranches(); - boolean removeBranchBreaks = false; boolean removeCommonBreak = true; // all branches contain exit insns, common break is unreachable + List forBreakRemove = new ArrayList<>(); for (IContainer branch : branches) { if (branch == null) { removeCommonBreak = false; continue; } - IBlock lastBlock = RegionUtils.getLastBlock(branch); - InsnNode lastInsn = BlockUtils.getLastInsn(lastBlock); - if (lastInsn == null) { + BlockInsnPair last = RegionUtils.getLastInsnWithBlock(branch); + if (last == null) { return; } + InsnNode lastInsn = last.getInsn(); if (lastInsn.getType() == InsnType.BREAK) { - removeBranchBreaks = true; + IBlock block = last.getBlock(); + IContainer parent = RegionUtils.getBlockContainer(branch, block); + forBreakRemove.add(new BlockParentContainer(parent, block)); removeCommonBreak = false; } else if (!lastInsn.isExitEdgeInsn()) { removeCommonBreak = false; } } - if (removeBranchBreaks) { + if (!forBreakRemove.isEmpty()) { // common 'break' confirmed - for (IContainer branch : branches) { - if (branch == null) { - continue; - } - // remove breaks from all branches - IBlock lastBlock = RegionUtils.getLastBlock(branch); - if (lastBlock != null) { - removeBreak(lastBlock, branch); - } + for (BlockParentContainer breakData : forBreakRemove) { + removeBreak(breakData.getBlock(), breakData.getParent()); } if (!dontAddCommonBreak) { addBreakRegion.add(parentRegion); } + // removed 'break' may allow to use 'else-if' chain + mth.add(AFlag.REQUEST_IF_REGION_OPTIMIZE); } if (removeCommonBreak && lastParentBlock != null) { removeBreak(lastParentBlock, parentRegion); diff --git a/jadx-core/src/main/java/jadx/core/utils/BlockInsnPair.java b/jadx-core/src/main/java/jadx/core/utils/BlockInsnPair.java index 44c56048d..8d5841f42 100644 --- a/jadx-core/src/main/java/jadx/core/utils/BlockInsnPair.java +++ b/jadx-core/src/main/java/jadx/core/utils/BlockInsnPair.java @@ -2,19 +2,19 @@ package jadx.core.utils; import java.util.Objects; -import jadx.core.dex.nodes.BlockNode; +import jadx.core.dex.nodes.IBlock; import jadx.core.dex.nodes.InsnNode; public class BlockInsnPair { - private final BlockNode block; + private final IBlock block; private final InsnNode insn; - public BlockInsnPair(BlockNode block, InsnNode insn) { + public BlockInsnPair(IBlock block, InsnNode insn) { this.block = block; this.insn = insn; } - public BlockNode getBlock() { + public IBlock getBlock() { return block; } diff --git a/jadx-core/src/main/java/jadx/core/utils/BlockParentContainer.java b/jadx-core/src/main/java/jadx/core/utils/BlockParentContainer.java new file mode 100644 index 000000000..cd1b6aa05 --- /dev/null +++ b/jadx-core/src/main/java/jadx/core/utils/BlockParentContainer.java @@ -0,0 +1,30 @@ +package jadx.core.utils; + +import java.util.Objects; + +import jadx.core.dex.nodes.IBlock; +import jadx.core.dex.nodes.IContainer; + +public class BlockParentContainer { + + private final IContainer parent; + private final IBlock block; + + public BlockParentContainer(IContainer parent, IBlock block) { + this.parent = Objects.requireNonNull(parent); + this.block = Objects.requireNonNull(block); + } + + public IBlock getBlock() { + return block; + } + + public IContainer getParent() { + return parent; + } + + @Override + public String toString() { + return "BlockParentContainer{" + block + ", parent=" + parent + '}'; + } +} diff --git a/jadx-core/src/main/java/jadx/core/utils/RegionUtils.java b/jadx-core/src/main/java/jadx/core/utils/RegionUtils.java index 18e5ea30f..9d16e8040 100644 --- a/jadx-core/src/main/java/jadx/core/utils/RegionUtils.java +++ b/jadx-core/src/main/java/jadx/core/utils/RegionUtils.java @@ -3,6 +3,7 @@ package jadx.core.utils; import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.Objects; import java.util.Set; import java.util.function.Consumer; import java.util.function.Predicate; @@ -153,6 +154,40 @@ public class RegionUtils { } } + public static BlockInsnPair getLastInsnWithBlock(IContainer container) { + if (container instanceof IBlock) { + IBlock block = (IBlock) container; + InsnNode lastInsn = ListUtils.last(block.getInstructions()); + if (lastInsn == null) { + return null; + } + return new BlockInsnPair(block, lastInsn); + } + if (container instanceof IBranchRegion) { + List branches = ((IBranchRegion) container).getBranches(); + long count = branches.stream().filter(Objects::nonNull).count(); + if (count == 1) { + // single branch + for (IContainer branch : branches) { + if (branch != null) { + return getLastInsnWithBlock(branch); + } + } + } + // several last instructions + return null; + } + if (container instanceof IRegion) { + IRegion region = (IRegion) container; + List blocks = region.getSubBlocks(); + if (blocks.isEmpty()) { + return null; + } + return getLastInsnWithBlock(ListUtils.last(blocks)); + } + throw new JadxRuntimeException(unknownContainerType(container)); + } + public static IBlock getLastBlock(IContainer container) { if (container instanceof IBlock) { return (IBlock) container; @@ -439,10 +474,11 @@ public class RegionUtils { return true; } - public static IContainer getBlockContainer(IContainer container, BlockNode block) { + public static IContainer getBlockContainer(IContainer container, IBlock block) { if (container instanceof IBlock) { return container == block ? container : null; - } else if (container instanceof IRegion) { + } + if (container instanceof IRegion) { IRegion region = (IRegion) container; for (IContainer c : region.getSubBlocks()) { IContainer res = getBlockContainer(c, block); @@ -451,9 +487,8 @@ public class RegionUtils { } } return null; - } else { - throw new JadxRuntimeException(unknownContainerType(container)); } + throw new JadxRuntimeException(unknownContainerType(container)); } /** diff --git a/jadx-core/src/test/java/jadx/tests/integration/switches/TestSwitchBreak3.java b/jadx-core/src/test/java/jadx/tests/integration/switches/TestSwitchBreak3.java new file mode 100644 index 000000000..cf9972290 --- /dev/null +++ b/jadx-core/src/test/java/jadx/tests/integration/switches/TestSwitchBreak3.java @@ -0,0 +1,54 @@ +package jadx.tests.integration.switches; + +import jadx.tests.api.IntegrationTest; +import jadx.tests.api.extensions.profiles.TestProfile; +import jadx.tests.api.extensions.profiles.TestWithProfiles; + +import static jadx.tests.api.utils.assertj.JadxAssertions.assertThat; + +public class TestSwitchBreak3 extends IntegrationTest { + + @SuppressWarnings("SwitchStatementWithTooFewBranches") + public static class TestCls { + private int value; + + public void test(int i, boolean b1, boolean b2, boolean b3) { + setValue(-1); + switch (i) { + case 0: + if (b1 == b2) { + setValue(1); + // no break here; + } else if (b1 == b3) { + setValue(2); + // no break here; + } + break; + default: + setValue(0); + break; + } + } + + private void setValue(int value) { + this.value = value; + } + + public void check() { + test(0, true, true, true); + assertThat(value).isEqualTo(1); + test(0, true, false, true); + assertThat(value).isEqualTo(2); + test(1, true, true, true); + assertThat(value).isEqualTo(0); + } + } + + @TestWithProfiles({ TestProfile.JAVA11, TestProfile.D8_J11 }) + public void test() { + assertThat(getClassNode(TestCls.class)) + .code() + .countString(2, "break;") + .containsOne("} else if ("); + } +}