From 1d5368f5a2542aeb7c59aebbefb2ca35633d809f Mon Sep 17 00:00:00 2001 From: Skylot Date: Sat, 27 Dec 2014 23:16:30 +0300 Subject: [PATCH] core: improve out block detection in switch (issue #38) --- .../java/jadx/core/codegen/RegionGen.java | 2 +- .../dex/visitors/regions/RegionMaker.java | 84 +++++++++++++------ .../java/jadx/core/utils/RegionUtils.java | 13 ++- .../integration/switches/TestSwitchBreak.java | 51 +++++++++++ .../switches/TestSwitchContinue.java | 50 +++++++++++ .../switches/TestSwitchReturnFromCase.java | 53 ++++++++++++ 6 files changed, 221 insertions(+), 32 deletions(-) create mode 100644 jadx-core/src/test/java/jadx/tests/integration/switches/TestSwitchBreak.java create mode 100644 jadx-core/src/test/java/jadx/tests/integration/switches/TestSwitchContinue.java create mode 100644 jadx-core/src/test/java/jadx/tests/integration/switches/TestSwitchReturnFromCase.java diff --git a/jadx-core/src/main/java/jadx/core/codegen/RegionGen.java b/jadx-core/src/main/java/jadx/core/codegen/RegionGen.java index 4f17dcef3..84f15c15c 100644 --- a/jadx-core/src/main/java/jadx/core/codegen/RegionGen.java +++ b/jadx-core/src/main/java/jadx/core/codegen/RegionGen.java @@ -272,7 +272,7 @@ public class RegionGen extends InsnGen { boolean addBreak = true; if (RegionUtils.notEmpty(c)) { makeRegionIndent(code, c); - if (!RegionUtils.hasExitEdge(c)) { + if (RegionUtils.hasExitEdge(c)) { addBreak = false; } } diff --git a/jadx-core/src/main/java/jadx/core/dex/visitors/regions/RegionMaker.java b/jadx-core/src/main/java/jadx/core/dex/visitors/regions/RegionMaker.java index 6deba82b2..79119d864 100644 --- a/jadx-core/src/main/java/jadx/core/dex/visitors/regions/RegionMaker.java +++ b/jadx-core/src/main/java/jadx/core/dex/visitors/regions/RegionMaker.java @@ -653,38 +653,44 @@ public class RegionMaker { assert c != null; blocksMap.put(c, entry.getValue()); } - - BitSet succ = BlockUtils.blocksToBitSet(mth, block.getSuccessors()); - BitSet domsOn = BlockUtils.blocksToBitSet(mth, block.getDominatesOn()); - domsOn.xor(succ); // filter 'out' block - BlockNode defCase = getBlockByOffset(insn.getDefaultCaseOffset(), block.getSuccessors()); if (defCase != null) { blocksMap.remove(defCase); } + LoopInfo loop = mth.getLoopForBlock(block); - int outCount = domsOn.cardinality(); - if (outCount > 1) { - // remove exception handlers - BlockUtils.cleanBitSet(mth, domsOn); - outCount = domsOn.cardinality(); + BitSet outs = new BitSet(mth.getBasicBlocks().size()); + outs.or(block.getDomFrontier()); + for (BlockNode s : block.getCleanSuccessors()) { + outs.or(s.getDomFrontier()); } - if (outCount > 1) { - // filter successors of other blocks + stack.push(sw); + stack.addExits(BlockUtils.bitSetToBlocks(mth, outs)); + + // filter 'out' block + if (outs.cardinality() > 1) { + // remove exception handlers + BlockUtils.cleanBitSet(mth, outs); + } + if (outs.cardinality() > 1) { + // filter loop start and successors of other blocks List blocks = mth.getBasicBlocks(); - for (int i = domsOn.nextSetBit(0); i >= 0; i = domsOn.nextSetBit(i + 1)) { + for (int i = outs.nextSetBit(0); i >= 0; i = outs.nextSetBit(i + 1)) { BlockNode b = blocks.get(i); - for (BlockNode s : b.getCleanSuccessors()) { - domsOn.clear(s.getId()); + if (b.contains(AFlag.LOOP_START)) { + outs.clear(b.getId()); + } else { + for (BlockNode s : b.getCleanSuccessors()) { + outs.clear(s.getId()); + } } } - outCount = domsOn.cardinality(); } - BlockNode out = null; - if (outCount == 1) { - out = mth.getBasicBlocks().get(domsOn.nextSetBit(0)); - } else if (outCount == 0) { + if (loop != null && outs.cardinality() > 1) { + outs.clear(loop.getEnd().getId()); + } + if (outs.cardinality() == 0) { // one or several case blocks are empty, // run expensive algorithm for find 'out' block for (BlockNode maybeOut : block.getSuccessors()) { @@ -696,18 +702,24 @@ public class RegionMaker { } } if (allReached) { - out = maybeOut; + outs.set(maybeOut.getId()); break; } } } - - stack.push(sw); - if (out != null) { + BlockNode out = null; + if (outs.cardinality() == 1) { + out = mth.getBasicBlocks().get(outs.nextSetBit(0)); stack.addExit(out); - } else { - LOG.warn("Can't detect out node for switch block: {} in {}", - block.toString(), mth.toString()); + } else if (loop == null && outs.cardinality() > 1) { + LOG.warn("Can't detect out node for switch block: {} in {}", block, mth); + } + if (loop != null) { + // check if 'continue' must be inserted + BlockNode end = loop.getEnd(); + if (out != end && out != null) { + insertContinueInSwitch(block, out, end); + } } if (!stack.containsExit(defCase)) { @@ -727,6 +739,24 @@ public class RegionMaker { return out; } + private static void insertContinueInSwitch(BlockNode block, BlockNode out, BlockNode end) { + int endId = end.getId(); + for (BlockNode s : block.getCleanSuccessors()) { + if (s.getDomFrontier().get(endId) && s != out) { + // search predecessor of loop end on path from this successor + List list = BlockUtils.collectBlocksDominatedBy(s, s); + for (BlockNode p : end.getPredecessors()) { + if (list.contains(p)) { + if (p.isSynthetic()) { + p.getInstructions().add(new InsnNode(InsnType.CONTINUE, 0)); + } + break; + } + } + } + } + } + public void processTryCatchBlocks(MethodNode mth) { Set tcs = new HashSet(); for (ExceptionHandler handler : mth.getExceptionHandlers()) { diff --git a/jadx-core/src/main/java/jadx/core/utils/RegionUtils.java b/jadx-core/src/main/java/jadx/core/utils/RegionUtils.java index aae65d2f6..10c2f8c1a 100644 --- a/jadx-core/src/main/java/jadx/core/utils/RegionUtils.java +++ b/jadx-core/src/main/java/jadx/core/utils/RegionUtils.java @@ -1,6 +1,5 @@ package jadx.core.utils; -import jadx.core.dex.attributes.AFlag; import jadx.core.dex.attributes.AType; import jadx.core.dex.instructions.InsnType; import jadx.core.dex.nodes.BlockNode; @@ -26,9 +25,15 @@ public class RegionUtils { public static boolean hasExitEdge(IContainer container) { if (container instanceof BlockNode) { - BlockNode block = (BlockNode) container; - return !block.getSuccessors().isEmpty() - && !block.contains(AFlag.RETURN); + InsnNode lastInsn = BlockUtils.getLastInsn((BlockNode) container); + if (lastInsn == null) { + return false; + } + InsnType type = lastInsn.getType(); + return type == InsnType.RETURN + || type == InsnType.CONTINUE + || type == InsnType.BREAK + || type == InsnType.THROW; } else if (container instanceof IRegion) { IRegion region = (IRegion) container; List blocks = region.getSubBlocks(); diff --git a/jadx-core/src/test/java/jadx/tests/integration/switches/TestSwitchBreak.java b/jadx-core/src/test/java/jadx/tests/integration/switches/TestSwitchBreak.java new file mode 100644 index 000000000..f91c716ff --- /dev/null +++ b/jadx-core/src/test/java/jadx/tests/integration/switches/TestSwitchBreak.java @@ -0,0 +1,51 @@ +package jadx.tests.integration.switches; + +import jadx.core.dex.nodes.ClassNode; +import jadx.tests.api.IntegrationTest; + +import org.junit.Test; + +import static jadx.tests.api.utils.JadxMatchers.containsOne; +import static org.hamcrest.CoreMatchers.containsString; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThat; + +public class TestSwitchBreak extends IntegrationTest { + + public static class TestCls { + public String test(int a) { + String s = ""; + loop: + while (a > 0) { + switch (a % 4) { + case 1: + s += "1"; + break; + case 3: + case 4: + s += "4"; + break; + case 5: + s += "+"; + break loop; + } + s += "-"; + a--; + } + return s; + } + } + + @Test + public void test() { + ClassNode cls = getClassNode(TestCls.class); + String code = cls.getCode().toString(); + + assertThat(code, containsString("switch (a % 4) {")); + assertEquals(4, count(code, "case ")); + assertEquals(3, count(code, "break;")); + + // TODO finish break with label from switch + assertThat(code, containsOne("return s + \"+\";")); + } +} diff --git a/jadx-core/src/test/java/jadx/tests/integration/switches/TestSwitchContinue.java b/jadx-core/src/test/java/jadx/tests/integration/switches/TestSwitchContinue.java new file mode 100644 index 000000000..e6e9649ce --- /dev/null +++ b/jadx-core/src/test/java/jadx/tests/integration/switches/TestSwitchContinue.java @@ -0,0 +1,50 @@ +package jadx.tests.integration.switches; + +import jadx.core.dex.nodes.ClassNode; +import jadx.tests.api.IntegrationTest; + +import org.junit.Test; + +import static jadx.tests.api.utils.JadxMatchers.containsOne; +import static org.hamcrest.CoreMatchers.containsString; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThat; + +public class TestSwitchContinue extends IntegrationTest { + + public static class TestCls { + public String test(int a) { + String s = ""; + while (a > 0) { + switch (a % 4) { + case 1: + s += "1"; + break; + case 3: + case 4: + s += "4"; + break; + case 5: + a -= 2; + continue; + } + s += "-"; + a--; + } + return s; + } + } + + @Test + public void test() { + ClassNode cls = getClassNode(TestCls.class); + String code = cls.getCode().toString(); + + assertThat(code, containsString("switch (a % 4) {")); + assertEquals(4, count(code, "case ")); + assertEquals(2, count(code, "break;")); + + assertThat(code, containsOne("a -= 2;")); + assertThat(code, containsOne("continue;")); + } +} diff --git a/jadx-core/src/test/java/jadx/tests/integration/switches/TestSwitchReturnFromCase.java b/jadx-core/src/test/java/jadx/tests/integration/switches/TestSwitchReturnFromCase.java new file mode 100644 index 000000000..119c610cd --- /dev/null +++ b/jadx-core/src/test/java/jadx/tests/integration/switches/TestSwitchReturnFromCase.java @@ -0,0 +1,53 @@ +package jadx.tests.integration.switches; + +import jadx.core.dex.nodes.ClassNode; +import jadx.tests.api.IntegrationTest; + +import org.junit.Test; + +import static jadx.tests.api.utils.JadxMatchers.containsOne; +import static org.hamcrest.CoreMatchers.containsString; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThat; + +public class TestSwitchReturnFromCase extends IntegrationTest { + + public static class TestCls { + public void test(int a) { + String s = null; + if (a > 1000) { + return; + } + switch (a % 4) { + case 1: + s = "1"; + break; + case 2: + s = "2"; + break; + case 3: + case 4: + s = "4"; + break; + case 5: + return; + } + s = "5"; + } + } + + @Test + public void test() { + ClassNode cls = getClassNode(TestCls.class); + String code = cls.getCode().toString(); + + assertThat(code, containsString("switch (a % 4) {")); + assertEquals(5, count(code, "case ")); + assertEquals(3, count(code, "break;")); + + assertThat(code, containsOne("s = \"1\";")); + assertThat(code, containsOne("s = \"2\";")); + assertThat(code, containsOne("s = \"4\";")); + assertThat(code, containsOne("s = \"5\";")); + } +}