From 7abbc8188687de6a5c87b3f989368dd9b0ef44c0 Mon Sep 17 00:00:00 2001 From: Skylot <118523+skylot@users.noreply.github.com> Date: Sun, 22 Sep 2024 21:09:10 +0100 Subject: [PATCH] fix: improve switch out block search if all method exits are inside (#2264) --- .../regions/maker/SwitchRegionMaker.java | 69 +++++++++++++++++-- .../integration/switches/TestSwitch3.java | 4 +- .../integration/switches/TestSwitch4.java | 40 +++++++++++ 3 files changed, 106 insertions(+), 7 deletions(-) create mode 100644 jadx-core/src/test/java/jadx/tests/integration/switches/TestSwitch4.java diff --git a/jadx-core/src/main/java/jadx/core/dex/visitors/regions/maker/SwitchRegionMaker.java b/jadx-core/src/main/java/jadx/core/dex/visitors/regions/maker/SwitchRegionMaker.java index b6620c1c9..76b746a92 100644 --- a/jadx-core/src/main/java/jadx/core/dex/visitors/regions/maker/SwitchRegionMaker.java +++ b/jadx-core/src/main/java/jadx/core/dex/visitors/regions/maker/SwitchRegionMaker.java @@ -15,6 +15,9 @@ import jadx.core.dex.attributes.nodes.LoopInfo; import jadx.core.dex.attributes.nodes.RegionRefAttr; import jadx.core.dex.instructions.InsnType; import jadx.core.dex.instructions.SwitchInsn; +import jadx.core.dex.instructions.args.ArgType; +import jadx.core.dex.instructions.args.InsnArg; +import jadx.core.dex.instructions.args.RegisterArg; import jadx.core.dex.nodes.BlockNode; import jadx.core.dex.nodes.IRegion; import jadx.core.dex.nodes.InsnNode; @@ -132,11 +135,6 @@ final class SwitchRegionMaker { } outs.clear(block.getId()); outs.clear(mth.getExitBlock().getId()); - if (outs.isEmpty()) { - // switch already contains method exit - // add everything, out block not needed - return mth.getExitBlock(); - } BlockNode out = null; if (outs.cardinality() == 1) { @@ -161,6 +159,10 @@ final class SwitchRegionMaker { out = possibleOut; } } + if (outs.isEmpty()) { + // all exits inside switch, keep inside to exit from loop + return mth.getExitBlock(); + } } if (out == null) { BlockNode imPostDom = block.getIPostDom(); @@ -177,6 +179,11 @@ final class SwitchRegionMaker { out = mth.getExitBlock(); } BlockNode imPostDom = block.getIPostDom(); + if (out == null && imPostDom == mth.getExitBlock()) { + // all exits inside switch + // check if all returns are equals and should be treated as single out block + return allSameReturns(stack); + } if (out != imPostDom && !mth.isPreExitBlock(imPostDom)) { // stop other paths at common exit stack.addExit(imPostDom); @@ -197,6 +204,58 @@ final class SwitchRegionMaker { return out; } + private BlockNode allSameReturns(RegionStack stack) { + BlockNode exitBlock = mth.getExitBlock(); + List preds = exitBlock.getPredecessors(); + int count = preds.size(); + if (count == 1) { + return preds.get(0); + } + if (mth.getReturnType() == ArgType.VOID) { + for (BlockNode pred : preds) { + InsnNode insn = BlockUtils.getLastInsn(pred); + if (insn == null || insn.getType() != InsnType.RETURN) { + return exitBlock; + } + } + } else { + List returnArgs = new ArrayList<>(); + for (BlockNode pred : preds) { + InsnNode insn = BlockUtils.getLastInsn(pred); + if (insn == null || insn.getType() != InsnType.RETURN) { + return exitBlock; + } + returnArgs.add(insn.getArg(0)); + } + InsnArg firstArg = returnArgs.get(0); + if (firstArg.isRegister()) { + RegisterArg reg = (RegisterArg) firstArg; + for (int i = 1; i < count; i++) { + InsnArg arg = returnArgs.get(1); + if (!arg.isRegister() || !((RegisterArg) arg).sameCodeVar(reg)) { + return exitBlock; + } + } + } else { + for (int i = 1; i < count; i++) { + InsnArg arg = returnArgs.get(1); + if (!arg.equals(firstArg)) { + return exitBlock; + } + } + } + } + // confirmed + stack.addExits(preds); + // ignore other returns + for (int i = 1; i < count; i++) { + BlockNode block = preds.get(i); + block.add(AFlag.REMOVE); + block.add(AFlag.ADDED_TO_REGION); + } + return preds.get(0); + } + /** * Remove empty case blocks: * 1. single 'default' case diff --git a/jadx-core/src/test/java/jadx/tests/integration/switches/TestSwitch3.java b/jadx-core/src/test/java/jadx/tests/integration/switches/TestSwitch3.java index 7ec100cd4..c31475e50 100644 --- a/jadx-core/src/test/java/jadx/tests/integration/switches/TestSwitch3.java +++ b/jadx-core/src/test/java/jadx/tests/integration/switches/TestSwitch3.java @@ -45,7 +45,7 @@ public class TestSwitch3 extends IntegrationTest { public void test() { assertThat(getClassNode(TestCls.class)) .code() - .countString(0, "break;") - .countString(3, "return;"); + .countString(3, "break;") + .countString(0, "return;"); } } diff --git a/jadx-core/src/test/java/jadx/tests/integration/switches/TestSwitch4.java b/jadx-core/src/test/java/jadx/tests/integration/switches/TestSwitch4.java new file mode 100644 index 000000000..a98de4adb --- /dev/null +++ b/jadx-core/src/test/java/jadx/tests/integration/switches/TestSwitch4.java @@ -0,0 +1,40 @@ +package jadx.tests.integration.switches; + +import org.junit.jupiter.api.Test; + +import jadx.tests.api.IntegrationTest; + +import static jadx.tests.api.utils.assertj.JadxAssertions.assertThat; + +public class TestSwitch4 extends IntegrationTest { + + public static class TestCls { + @SuppressWarnings({ "FallThrough", "unused" }) + private static int parse(char[] ch, int off, int len) { + int num = ch[off + len - 1] - '0'; + switch (len) { + case 4: + num += (ch[off++] - '0') * 1000; + case 3: + num += (ch[off++] - '0') * 100; + case 2: + num += (ch[off] - '0') * 10; + } + return num; + } + + public void check() { + assertThat(parse("123".toCharArray(), 0, 3)).isEqualTo(123); + assertThat(parse("a=1234".toCharArray(), 2, 4)).isEqualTo(1234); + } + } + + @Test + public void test() { + assertThat(getClassNode(TestCls.class)) + .code() + .containsOne("switch (") + .countString(3, "case ") + .doesNotContain("break"); + } +}