From 02f9c25f52d6484a1ea1514044a8fed8666efc86 Mon Sep 17 00:00:00 2001 From: Skylot Date: Sat, 10 Jan 2015 18:29:53 +0300 Subject: [PATCH] core: support fall through cases in switch --- .../java/jadx/core/dex/attributes/AFlag.java | 2 + .../dex/visitors/regions/RegionMaker.java | 91 ++++++++++++++++++- .../visitors/regions/RegionMakerVisitor.java | 60 ++++++++++-- .../dex/visitors/regions/RegionStack.java | 6 ++ .../java/jadx/core/utils/RegionUtils.java | 22 ++++- .../integration/switches/TestSwitch2.java | 8 +- .../TestSwitchWithFallThroughCase.java | 60 ++++++++++++ .../TestSwitchWithFallThroughCase2.java | 67 ++++++++++++++ .../switches/TestSwitchWithTryCatch.java | 5 +- 9 files changed, 302 insertions(+), 19 deletions(-) create mode 100644 jadx-core/src/test/java/jadx/tests/integration/switches/TestSwitchWithFallThroughCase.java create mode 100644 jadx-core/src/test/java/jadx/tests/integration/switches/TestSwitchWithFallThroughCase2.java diff --git a/jadx-core/src/main/java/jadx/core/dex/attributes/AFlag.java b/jadx-core/src/main/java/jadx/core/dex/attributes/AFlag.java index 6b2e217ff..9093f2896 100644 --- a/jadx-core/src/main/java/jadx/core/dex/attributes/AFlag.java +++ b/jadx-core/src/main/java/jadx/core/dex/attributes/AFlag.java @@ -29,5 +29,7 @@ public enum AFlag { WRAPPED, ARITH_ONEARG, + FALL_THROUGH, + INCONSISTENT_CODE, // warning about incorrect decompilation } diff --git a/jadx-core/src/main/java/jadx/core/dex/visitors/regions/RegionMaker.java b/jadx-core/src/main/java/jadx/core/dex/visitors/regions/RegionMaker.java index 79119d864..1444c26a1 100644 --- a/jadx-core/src/main/java/jadx/core/dex/visitors/regions/RegionMaker.java +++ b/jadx-core/src/main/java/jadx/core/dex/visitors/regions/RegionMaker.java @@ -32,6 +32,8 @@ import jadx.core.utils.exceptions.JadxOverflowException; import java.util.ArrayList; import java.util.BitSet; +import java.util.Collections; +import java.util.Comparator; import java.util.HashSet; import java.util.LinkedHashMap; import java.util.List; @@ -659,14 +661,47 @@ public class RegionMaker { } LoopInfo loop = mth.getLoopForBlock(block); + Map fallThroughCases = new LinkedHashMap(); + BitSet outs = new BitSet(mth.getBasicBlocks().size()); outs.or(block.getDomFrontier()); for (BlockNode s : block.getCleanSuccessors()) { - outs.or(s.getDomFrontier()); + BitSet df = s.getDomFrontier(); + // fall through case block + if (df.cardinality() > 1) { + if (df.cardinality() > 2) { + LOG.debug("Unexpected case pattern, block: {}, mth: {}", s, mth); + } else { + BlockNode first = mth.getBasicBlocks().get(df.nextSetBit(0)); + BlockNode second = mth.getBasicBlocks().get(df.nextSetBit(first.getId() + 1)); + if (second.getDomFrontier().get(first.getId())) { + fallThroughCases.put(s, second); + df = new BitSet(df.size()); + df.set(first.getId()); + } else if (first.getDomFrontier().get(second.getId())) { + fallThroughCases.put(s, first); + df = new BitSet(df.size()); + df.set(second.getId()); + } + } + } + outs.or(df); } stack.push(sw); stack.addExits(BlockUtils.bitSetToBlocks(mth, outs)); + // check cases order if fall through case exists + if (!fallThroughCases.isEmpty()) { + if (isBadCasesOrder(blocksMap, fallThroughCases)) { + LOG.debug("Fixing incorrect switch cases order"); + blocksMap = reOrderSwitchCases(blocksMap, fallThroughCases); + if (isBadCasesOrder(blocksMap, fallThroughCases)) { + LOG.error("Can't fix incorrect switch cases order, method: {}", mth); + mth.add(AFlag.INCONSISTENT_CODE); + } + } + } + // filter 'out' block if (outs.cardinality() > 1) { // remove exception handlers @@ -677,6 +712,7 @@ public class RegionMaker { List blocks = mth.getBasicBlocks(); for (int i = outs.nextSetBit(0); i >= 0; i = outs.nextSetBit(i + 1)) { BlockNode b = blocks.get(i); + outs.andNot(b.getDomFrontier()); if (b.contains(AFlag.LOOP_START)) { outs.clear(b.getId()); } else { @@ -726,12 +762,21 @@ public class RegionMaker { sw.setDefaultCase(makeRegion(defCase, stack)); } for (Entry> entry : blocksMap.entrySet()) { - BlockNode c = entry.getKey(); - if (stack.containsExit(c)) { + BlockNode caseBlock = entry.getKey(); + if (stack.containsExit(caseBlock)) { // empty case block sw.addCase(entry.getValue(), new Region(stack.peekRegion())); } else { - sw.addCase(entry.getValue(), makeRegion(c, stack)); + BlockNode next = fallThroughCases.get(caseBlock); + stack.addExit(next); + Region caseRegion = makeRegion(caseBlock, stack); + stack.removeExit(next); + if (next != null) { + next.add(AFlag.FALL_THROUGH); + caseRegion.add(AFlag.FALL_THROUGH); + } + sw.addCase(entry.getValue(), caseRegion); + // 'break' instruction will be inserted in RegionMakerVisitor.PostRegionVisitor } } @@ -739,6 +784,44 @@ public class RegionMaker { return out; } + private boolean isBadCasesOrder(final Map> blocksMap, + final Map fallThroughCases) { + BlockNode nextCaseBlock = null; + for (BlockNode caseBlock : blocksMap.keySet()) { + if (nextCaseBlock != null && !caseBlock.equals(nextCaseBlock)) { + return true; + } + nextCaseBlock = fallThroughCases.get(caseBlock); + } + return nextCaseBlock != null; + } + + private Map> reOrderSwitchCases(Map> blocksMap, + final Map fallThroughCases) { + List list = new ArrayList(blocksMap.size()); + list.addAll(blocksMap.keySet()); + Collections.sort(list, new Comparator() { + @Override + public int compare(BlockNode a, BlockNode b) { + BlockNode nextA = fallThroughCases.get(a); + if (nextA != null) { + if (b.equals(nextA)) { + return -1; + } + } else if (a.equals(fallThroughCases.get(b))) { + return 1; + } + return 0; + } + }); + + Map> newBlocksMap = new LinkedHashMap>(blocksMap.size()); + for (BlockNode key : list) { + newBlocksMap.put(key, blocksMap.get(key)); + } + return newBlocksMap; + } + private static void insertContinueInSwitch(BlockNode block, BlockNode out, BlockNode end) { int endId = end.getId(); for (BlockNode s : block.getCleanSuccessors()) { diff --git a/jadx-core/src/main/java/jadx/core/dex/visitors/regions/RegionMakerVisitor.java b/jadx-core/src/main/java/jadx/core/dex/visitors/regions/RegionMakerVisitor.java index 3e28b6416..f735c95e1 100644 --- a/jadx-core/src/main/java/jadx/core/dex/visitors/regions/RegionMakerVisitor.java +++ b/jadx-core/src/main/java/jadx/core/dex/visitors/regions/RegionMakerVisitor.java @@ -1,6 +1,9 @@ package jadx.core.dex.visitors.regions; +import jadx.core.dex.attributes.AFlag; import jadx.core.dex.instructions.InsnType; +import jadx.core.dex.nodes.BlockNode; +import jadx.core.dex.nodes.IBlock; import jadx.core.dex.nodes.IContainer; import jadx.core.dex.nodes.IRegion; import jadx.core.dex.nodes.InsnContainer; @@ -16,7 +19,9 @@ import jadx.core.utils.RegionUtils; import jadx.core.utils.exceptions.JadxException; import java.util.ArrayList; +import java.util.HashSet; import java.util.List; +import java.util.Set; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -62,25 +67,66 @@ public class RegionMakerVisitor extends AbstractVisitor { private static final class PostRegionVisitor extends AbstractRegionVisitor { @Override - public void enterRegion(MethodNode mth, IRegion region) { + public void leaveRegion(MethodNode mth, IRegion region) { if (region instanceof LoopRegion) { // merge conditions in loops LoopRegion loop = (LoopRegion) region; loop.mergePreCondition(); } else if (region instanceof SwitchRegion) { // insert 'break' in switch cases (run after try/catch insertion) - SwitchRegion sw = (SwitchRegion) region; - for (IContainer c : sw.getBranches()) { - if (c instanceof Region && !RegionUtils.hasExitEdge(c)) { - List insns = new ArrayList(1); - insns.add(new InsnNode(InsnType.BREAK, 0)); - ((Region) c).add(new InsnContainer(insns)); + processSwitch(mth, (SwitchRegion) region); + } + } + } + + private static void processSwitch(MethodNode mth, SwitchRegion sw) { + for (IContainer c : sw.getBranches()) { + if (!(c instanceof Region)) { + continue; + } + Set blocks = new HashSet(); + RegionUtils.getAllRegionBlocks(c, blocks); + if (blocks.isEmpty()) { + addBreakToContainer((Region) c); + continue; + } + for (IBlock block : blocks) { + if (!(block instanceof BlockNode)) { + continue; + } + BlockNode bn = (BlockNode) block; + for (BlockNode s : bn.getCleanSuccessors()) { + if (!blocks.contains(s) + && !bn.contains(AFlag.SKIP) + && !s.contains(AFlag.FALL_THROUGH)) { + addBreak(mth, c, bn); + break; } } } } } + private static void addBreak(MethodNode mth, IContainer c, BlockNode bn) { + IContainer blockContainer = RegionUtils.getBlockContainer(c, bn); + if (blockContainer instanceof Region) { + addBreakToContainer((Region) blockContainer); + } else if (c instanceof Region) { + addBreakToContainer((Region) c); + } else { + LOG.warn("Can't insert break, container: {}, block: {}, mth: {}", blockContainer, bn, mth); + } + } + + private static void addBreakToContainer(Region c) { + if (RegionUtils.hasExitEdge(c)) { + return; + } + List insns = new ArrayList(1); + insns.add(new InsnNode(InsnType.BREAK, 0)); + c.add(new InsnContainer(insns)); + } + private static void removeSynchronized(MethodNode mth) { Region startRegion = mth.getRegion(); List subBlocks = startRegion.getSubBlocks(); diff --git a/jadx-core/src/main/java/jadx/core/dex/visitors/regions/RegionStack.java b/jadx-core/src/main/java/jadx/core/dex/visitors/regions/RegionStack.java index bdba01e23..8b09c59b2 100644 --- a/jadx-core/src/main/java/jadx/core/dex/visitors/regions/RegionStack.java +++ b/jadx-core/src/main/java/jadx/core/dex/visitors/regions/RegionStack.java @@ -95,6 +95,12 @@ final class RegionStack { } } + public void removeExit(BlockNode exit) { + if (exit != null) { + curState.exits.remove(exit); + } + } + public boolean containsExit(BlockNode exit) { return curState.exits.contains(exit); } diff --git a/jadx-core/src/main/java/jadx/core/utils/RegionUtils.java b/jadx-core/src/main/java/jadx/core/utils/RegionUtils.java index 946d27102..ed2286bc8 100644 --- a/jadx-core/src/main/java/jadx/core/utils/RegionUtils.java +++ b/jadx-core/src/main/java/jadx/core/utils/RegionUtils.java @@ -8,8 +8,6 @@ import jadx.core.dex.nodes.IBranchRegion; import jadx.core.dex.nodes.IContainer; import jadx.core.dex.nodes.IRegion; import jadx.core.dex.nodes.InsnNode; -import jadx.core.dex.regions.SwitchRegion; -import jadx.core.dex.regions.conditions.IfRegion; import jadx.core.dex.trycatch.CatchAttr; import jadx.core.dex.trycatch.ExceptionHandler; import jadx.core.dex.trycatch.TryCatchBlock; @@ -60,8 +58,7 @@ public class RegionUtils { return null; } return insnList.get(insnList.size() - 1); - } else if (container instanceof IfRegion - || container instanceof SwitchRegion) { + } else if (container instanceof IBranchRegion) { return null; } else if (container instanceof IRegion) { IRegion region = (IRegion) container; @@ -235,6 +232,23 @@ public class RegionUtils { return true; } + public static IContainer getBlockContainer(IContainer container, BlockNode block) { + if (container instanceof IBlock) { + return container == block ? container : null; + } else if (container instanceof IRegion) { + IRegion region = (IRegion) container; + for (IContainer c : region.getSubBlocks()) { + IContainer res = getBlockContainer(c, block); + if (res != null) { + return res instanceof IBlock ? region : res; + } + } + return null; + } else { + throw new JadxRuntimeException("Unknown container type: " + container.getClass()); + } + } + public static boolean isDominatedBy(BlockNode dom, IContainer cont) { if (dom == cont) { return true; diff --git a/jadx-core/src/test/java/jadx/tests/integration/switches/TestSwitch2.java b/jadx-core/src/test/java/jadx/tests/integration/switches/TestSwitch2.java index a420b1ad5..225135ebf 100644 --- a/jadx-core/src/test/java/jadx/tests/integration/switches/TestSwitch2.java +++ b/jadx-core/src/test/java/jadx/tests/integration/switches/TestSwitch2.java @@ -60,9 +60,11 @@ public class TestSwitch2 extends IntegrationTest { ClassNode cls = getClassNode(TestCls.class); String code = cls.getCode().toString(); - assertThat(code, countString(4, "break;")); +// assertThat(code, countString(4, "break;")); +// assertThat(code, countString(2, "return;")); - // TODO: remove redundant returns - // assertThat(code, countString(2, "return;")); + // TODO: remove redundant break and returns + assertThat(code, countString(5, "break;")); + assertThat(code, countString(4, "return;")); } } diff --git a/jadx-core/src/test/java/jadx/tests/integration/switches/TestSwitchWithFallThroughCase.java b/jadx-core/src/test/java/jadx/tests/integration/switches/TestSwitchWithFallThroughCase.java new file mode 100644 index 000000000..2d1178f51 --- /dev/null +++ b/jadx-core/src/test/java/jadx/tests/integration/switches/TestSwitchWithFallThroughCase.java @@ -0,0 +1,60 @@ +package jadx.tests.integration.switches; + +import jadx.core.dex.nodes.ClassNode; +import jadx.tests.api.IntegrationTest; + +import org.junit.Test; + +import static jadx.tests.api.utils.JadxMatchers.containsOne; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThat; + +public class TestSwitchWithFallThroughCase extends IntegrationTest { + + public static class TestCls { + public String test(int a, boolean b, boolean c) { + String str = ""; + switch (a % 4) { + case 1: + str += ">"; + if (a == 5 && b) { + if (c) { + str += "1"; + } else { + str += "!c"; + } + break; + } + case 2: + if (b) { + str += "2"; + } + break; + case 3: + break; + default: + str += "default"; + break; + } + str += ";"; + return str; + } + + public void check() { + assertEquals(">1;", test(5, true, true)); + assertEquals(">2;", test(1, true, true)); + assertEquals(";", test(3, true, true)); + assertEquals("default;", test(0, true, true)); + } + } + + @Test + public void test() { + ClassNode cls = getClassNode(TestCls.class); + String code = cls.getCode().toString(); + + assertThat(code, containsOne("switch (a % 4) {")); + assertThat(code, containsOne("if (a == 5 && b) {")); + assertThat(code, containsOne("if (b) {")); + } +} diff --git a/jadx-core/src/test/java/jadx/tests/integration/switches/TestSwitchWithFallThroughCase2.java b/jadx-core/src/test/java/jadx/tests/integration/switches/TestSwitchWithFallThroughCase2.java new file mode 100644 index 000000000..05f817101 --- /dev/null +++ b/jadx-core/src/test/java/jadx/tests/integration/switches/TestSwitchWithFallThroughCase2.java @@ -0,0 +1,67 @@ +package jadx.tests.integration.switches; + +import jadx.core.dex.nodes.ClassNode; +import jadx.tests.api.IntegrationTest; + +import org.junit.Test; + +import static jadx.tests.api.utils.JadxMatchers.containsOne; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThat; + +public class TestSwitchWithFallThroughCase2 extends IntegrationTest { + + public static class TestCls { + public String test(int a, boolean b, boolean c) { + String str = ""; + if (a > 0) { + switch (a % 4) { + case 1: + str += ">"; + if (a == 5 && b) { + if (c) { + str += "1"; + } else { + str += "!c"; + } + break; + } + case 2: + if (b) { + str += "2"; + } + break; + case 3: + break; + default: + str += "default"; + break; + } + str += "+"; + } + if (b && c) { + str += "-"; + } + return str; + } + + public void check() { + assertEquals(">1+-", test(5, true, true)); + assertEquals(">2+-", test(1, true, true)); + assertEquals("+-", test(3, true, true)); + assertEquals("default+-", test(16, true, true)); + assertEquals("-", test(-1, true, true)); + } + } + + @Test + public void test() { + setOutputCFG(); + ClassNode cls = getClassNode(TestCls.class); + String code = cls.getCode().toString(); + + assertThat(code, containsOne("switch (a % 4) {")); + assertThat(code, containsOne("if (a == 5 && b) {")); + assertThat(code, containsOne("if (b) {")); + } +} diff --git a/jadx-core/src/test/java/jadx/tests/integration/switches/TestSwitchWithTryCatch.java b/jadx-core/src/test/java/jadx/tests/integration/switches/TestSwitchWithTryCatch.java index ae1874c94..a32e96898 100644 --- a/jadx-core/src/test/java/jadx/tests/integration/switches/TestSwitchWithTryCatch.java +++ b/jadx-core/src/test/java/jadx/tests/integration/switches/TestSwitchWithTryCatch.java @@ -62,7 +62,10 @@ public class TestSwitchWithTryCatch extends IntegrationTest { ClassNode cls = getClassNode(TestCls.class); String code = cls.getCode().toString(); - assertThat(code, countString(3, "break;")); +// assertThat(code, countString(3, "break;")); assertThat(code, countString(4, "return;")); + + // TODO: remove redundant break + assertThat(code, countString(4, "break;")); } }