core: support fall through cases in switch
This commit is contained in:
@@ -29,5 +29,7 @@ public enum AFlag {
|
||||
WRAPPED,
|
||||
ARITH_ONEARG,
|
||||
|
||||
FALL_THROUGH,
|
||||
|
||||
INCONSISTENT_CODE, // warning about incorrect decompilation
|
||||
}
|
||||
|
||||
@@ -32,6 +32,8 @@ import jadx.core.utils.exceptions.JadxOverflowException;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.BitSet;
|
||||
import java.util.Collections;
|
||||
import java.util.Comparator;
|
||||
import java.util.HashSet;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.List;
|
||||
@@ -659,14 +661,47 @@ public class RegionMaker {
|
||||
}
|
||||
LoopInfo loop = mth.getLoopForBlock(block);
|
||||
|
||||
Map<BlockNode, BlockNode> fallThroughCases = new LinkedHashMap<BlockNode, BlockNode>();
|
||||
|
||||
BitSet outs = new BitSet(mth.getBasicBlocks().size());
|
||||
outs.or(block.getDomFrontier());
|
||||
for (BlockNode s : block.getCleanSuccessors()) {
|
||||
outs.or(s.getDomFrontier());
|
||||
BitSet df = s.getDomFrontier();
|
||||
// fall through case block
|
||||
if (df.cardinality() > 1) {
|
||||
if (df.cardinality() > 2) {
|
||||
LOG.debug("Unexpected case pattern, block: {}, mth: {}", s, mth);
|
||||
} else {
|
||||
BlockNode first = mth.getBasicBlocks().get(df.nextSetBit(0));
|
||||
BlockNode second = mth.getBasicBlocks().get(df.nextSetBit(first.getId() + 1));
|
||||
if (second.getDomFrontier().get(first.getId())) {
|
||||
fallThroughCases.put(s, second);
|
||||
df = new BitSet(df.size());
|
||||
df.set(first.getId());
|
||||
} else if (first.getDomFrontier().get(second.getId())) {
|
||||
fallThroughCases.put(s, first);
|
||||
df = new BitSet(df.size());
|
||||
df.set(second.getId());
|
||||
}
|
||||
}
|
||||
}
|
||||
outs.or(df);
|
||||
}
|
||||
stack.push(sw);
|
||||
stack.addExits(BlockUtils.bitSetToBlocks(mth, outs));
|
||||
|
||||
// check cases order if fall through case exists
|
||||
if (!fallThroughCases.isEmpty()) {
|
||||
if (isBadCasesOrder(blocksMap, fallThroughCases)) {
|
||||
LOG.debug("Fixing incorrect switch cases order");
|
||||
blocksMap = reOrderSwitchCases(blocksMap, fallThroughCases);
|
||||
if (isBadCasesOrder(blocksMap, fallThroughCases)) {
|
||||
LOG.error("Can't fix incorrect switch cases order, method: {}", mth);
|
||||
mth.add(AFlag.INCONSISTENT_CODE);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// filter 'out' block
|
||||
if (outs.cardinality() > 1) {
|
||||
// remove exception handlers
|
||||
@@ -677,6 +712,7 @@ public class RegionMaker {
|
||||
List<BlockNode> blocks = mth.getBasicBlocks();
|
||||
for (int i = outs.nextSetBit(0); i >= 0; i = outs.nextSetBit(i + 1)) {
|
||||
BlockNode b = blocks.get(i);
|
||||
outs.andNot(b.getDomFrontier());
|
||||
if (b.contains(AFlag.LOOP_START)) {
|
||||
outs.clear(b.getId());
|
||||
} else {
|
||||
@@ -726,12 +762,21 @@ public class RegionMaker {
|
||||
sw.setDefaultCase(makeRegion(defCase, stack));
|
||||
}
|
||||
for (Entry<BlockNode, List<Object>> entry : blocksMap.entrySet()) {
|
||||
BlockNode c = entry.getKey();
|
||||
if (stack.containsExit(c)) {
|
||||
BlockNode caseBlock = entry.getKey();
|
||||
if (stack.containsExit(caseBlock)) {
|
||||
// empty case block
|
||||
sw.addCase(entry.getValue(), new Region(stack.peekRegion()));
|
||||
} else {
|
||||
sw.addCase(entry.getValue(), makeRegion(c, stack));
|
||||
BlockNode next = fallThroughCases.get(caseBlock);
|
||||
stack.addExit(next);
|
||||
Region caseRegion = makeRegion(caseBlock, stack);
|
||||
stack.removeExit(next);
|
||||
if (next != null) {
|
||||
next.add(AFlag.FALL_THROUGH);
|
||||
caseRegion.add(AFlag.FALL_THROUGH);
|
||||
}
|
||||
sw.addCase(entry.getValue(), caseRegion);
|
||||
// 'break' instruction will be inserted in RegionMakerVisitor.PostRegionVisitor
|
||||
}
|
||||
}
|
||||
|
||||
@@ -739,6 +784,44 @@ public class RegionMaker {
|
||||
return out;
|
||||
}
|
||||
|
||||
private boolean isBadCasesOrder(final Map<BlockNode, List<Object>> blocksMap,
|
||||
final Map<BlockNode, BlockNode> fallThroughCases) {
|
||||
BlockNode nextCaseBlock = null;
|
||||
for (BlockNode caseBlock : blocksMap.keySet()) {
|
||||
if (nextCaseBlock != null && !caseBlock.equals(nextCaseBlock)) {
|
||||
return true;
|
||||
}
|
||||
nextCaseBlock = fallThroughCases.get(caseBlock);
|
||||
}
|
||||
return nextCaseBlock != null;
|
||||
}
|
||||
|
||||
private Map<BlockNode, List<Object>> reOrderSwitchCases(Map<BlockNode, List<Object>> blocksMap,
|
||||
final Map<BlockNode, BlockNode> fallThroughCases) {
|
||||
List<BlockNode> list = new ArrayList<BlockNode>(blocksMap.size());
|
||||
list.addAll(blocksMap.keySet());
|
||||
Collections.sort(list, new Comparator<BlockNode>() {
|
||||
@Override
|
||||
public int compare(BlockNode a, BlockNode b) {
|
||||
BlockNode nextA = fallThroughCases.get(a);
|
||||
if (nextA != null) {
|
||||
if (b.equals(nextA)) {
|
||||
return -1;
|
||||
}
|
||||
} else if (a.equals(fallThroughCases.get(b))) {
|
||||
return 1;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
});
|
||||
|
||||
Map<BlockNode, List<Object>> newBlocksMap = new LinkedHashMap<BlockNode, List<Object>>(blocksMap.size());
|
||||
for (BlockNode key : list) {
|
||||
newBlocksMap.put(key, blocksMap.get(key));
|
||||
}
|
||||
return newBlocksMap;
|
||||
}
|
||||
|
||||
private static void insertContinueInSwitch(BlockNode block, BlockNode out, BlockNode end) {
|
||||
int endId = end.getId();
|
||||
for (BlockNode s : block.getCleanSuccessors()) {
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package jadx.core.dex.visitors.regions;
|
||||
|
||||
import jadx.core.dex.attributes.AFlag;
|
||||
import jadx.core.dex.instructions.InsnType;
|
||||
import jadx.core.dex.nodes.BlockNode;
|
||||
import jadx.core.dex.nodes.IBlock;
|
||||
import jadx.core.dex.nodes.IContainer;
|
||||
import jadx.core.dex.nodes.IRegion;
|
||||
import jadx.core.dex.nodes.InsnContainer;
|
||||
@@ -16,7 +19,9 @@ import jadx.core.utils.RegionUtils;
|
||||
import jadx.core.utils.exceptions.JadxException;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
@@ -62,25 +67,66 @@ public class RegionMakerVisitor extends AbstractVisitor {
|
||||
|
||||
private static final class PostRegionVisitor extends AbstractRegionVisitor {
|
||||
@Override
|
||||
public void enterRegion(MethodNode mth, IRegion region) {
|
||||
public void leaveRegion(MethodNode mth, IRegion region) {
|
||||
if (region instanceof LoopRegion) {
|
||||
// merge conditions in loops
|
||||
LoopRegion loop = (LoopRegion) region;
|
||||
loop.mergePreCondition();
|
||||
} else if (region instanceof SwitchRegion) {
|
||||
// insert 'break' in switch cases (run after try/catch insertion)
|
||||
SwitchRegion sw = (SwitchRegion) region;
|
||||
for (IContainer c : sw.getBranches()) {
|
||||
if (c instanceof Region && !RegionUtils.hasExitEdge(c)) {
|
||||
List<InsnNode> insns = new ArrayList<InsnNode>(1);
|
||||
insns.add(new InsnNode(InsnType.BREAK, 0));
|
||||
((Region) c).add(new InsnContainer(insns));
|
||||
processSwitch(mth, (SwitchRegion) region);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static void processSwitch(MethodNode mth, SwitchRegion sw) {
|
||||
for (IContainer c : sw.getBranches()) {
|
||||
if (!(c instanceof Region)) {
|
||||
continue;
|
||||
}
|
||||
Set<IBlock> blocks = new HashSet<IBlock>();
|
||||
RegionUtils.getAllRegionBlocks(c, blocks);
|
||||
if (blocks.isEmpty()) {
|
||||
addBreakToContainer((Region) c);
|
||||
continue;
|
||||
}
|
||||
for (IBlock block : blocks) {
|
||||
if (!(block instanceof BlockNode)) {
|
||||
continue;
|
||||
}
|
||||
BlockNode bn = (BlockNode) block;
|
||||
for (BlockNode s : bn.getCleanSuccessors()) {
|
||||
if (!blocks.contains(s)
|
||||
&& !bn.contains(AFlag.SKIP)
|
||||
&& !s.contains(AFlag.FALL_THROUGH)) {
|
||||
addBreak(mth, c, bn);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static void addBreak(MethodNode mth, IContainer c, BlockNode bn) {
|
||||
IContainer blockContainer = RegionUtils.getBlockContainer(c, bn);
|
||||
if (blockContainer instanceof Region) {
|
||||
addBreakToContainer((Region) blockContainer);
|
||||
} else if (c instanceof Region) {
|
||||
addBreakToContainer((Region) c);
|
||||
} else {
|
||||
LOG.warn("Can't insert break, container: {}, block: {}, mth: {}", blockContainer, bn, mth);
|
||||
}
|
||||
}
|
||||
|
||||
private static void addBreakToContainer(Region c) {
|
||||
if (RegionUtils.hasExitEdge(c)) {
|
||||
return;
|
||||
}
|
||||
List<InsnNode> insns = new ArrayList<InsnNode>(1);
|
||||
insns.add(new InsnNode(InsnType.BREAK, 0));
|
||||
c.add(new InsnContainer(insns));
|
||||
}
|
||||
|
||||
private static void removeSynchronized(MethodNode mth) {
|
||||
Region startRegion = mth.getRegion();
|
||||
List<IContainer> subBlocks = startRegion.getSubBlocks();
|
||||
|
||||
@@ -95,6 +95,12 @@ final class RegionStack {
|
||||
}
|
||||
}
|
||||
|
||||
public void removeExit(BlockNode exit) {
|
||||
if (exit != null) {
|
||||
curState.exits.remove(exit);
|
||||
}
|
||||
}
|
||||
|
||||
public boolean containsExit(BlockNode exit) {
|
||||
return curState.exits.contains(exit);
|
||||
}
|
||||
|
||||
@@ -8,8 +8,6 @@ import jadx.core.dex.nodes.IBranchRegion;
|
||||
import jadx.core.dex.nodes.IContainer;
|
||||
import jadx.core.dex.nodes.IRegion;
|
||||
import jadx.core.dex.nodes.InsnNode;
|
||||
import jadx.core.dex.regions.SwitchRegion;
|
||||
import jadx.core.dex.regions.conditions.IfRegion;
|
||||
import jadx.core.dex.trycatch.CatchAttr;
|
||||
import jadx.core.dex.trycatch.ExceptionHandler;
|
||||
import jadx.core.dex.trycatch.TryCatchBlock;
|
||||
@@ -60,8 +58,7 @@ public class RegionUtils {
|
||||
return null;
|
||||
}
|
||||
return insnList.get(insnList.size() - 1);
|
||||
} else if (container instanceof IfRegion
|
||||
|| container instanceof SwitchRegion) {
|
||||
} else if (container instanceof IBranchRegion) {
|
||||
return null;
|
||||
} else if (container instanceof IRegion) {
|
||||
IRegion region = (IRegion) container;
|
||||
@@ -235,6 +232,23 @@ public class RegionUtils {
|
||||
return true;
|
||||
}
|
||||
|
||||
public static IContainer getBlockContainer(IContainer container, BlockNode block) {
|
||||
if (container instanceof IBlock) {
|
||||
return container == block ? container : null;
|
||||
} else if (container instanceof IRegion) {
|
||||
IRegion region = (IRegion) container;
|
||||
for (IContainer c : region.getSubBlocks()) {
|
||||
IContainer res = getBlockContainer(c, block);
|
||||
if (res != null) {
|
||||
return res instanceof IBlock ? region : res;
|
||||
}
|
||||
}
|
||||
return null;
|
||||
} else {
|
||||
throw new JadxRuntimeException("Unknown container type: " + container.getClass());
|
||||
}
|
||||
}
|
||||
|
||||
public static boolean isDominatedBy(BlockNode dom, IContainer cont) {
|
||||
if (dom == cont) {
|
||||
return true;
|
||||
|
||||
@@ -60,9 +60,11 @@ public class TestSwitch2 extends IntegrationTest {
|
||||
ClassNode cls = getClassNode(TestCls.class);
|
||||
String code = cls.getCode().toString();
|
||||
|
||||
assertThat(code, countString(4, "break;"));
|
||||
// assertThat(code, countString(4, "break;"));
|
||||
// assertThat(code, countString(2, "return;"));
|
||||
|
||||
// TODO: remove redundant returns
|
||||
// assertThat(code, countString(2, "return;"));
|
||||
// TODO: remove redundant break and returns
|
||||
assertThat(code, countString(5, "break;"));
|
||||
assertThat(code, countString(4, "return;"));
|
||||
}
|
||||
}
|
||||
|
||||
+60
@@ -0,0 +1,60 @@
|
||||
package jadx.tests.integration.switches;
|
||||
|
||||
import jadx.core.dex.nodes.ClassNode;
|
||||
import jadx.tests.api.IntegrationTest;
|
||||
|
||||
import org.junit.Test;
|
||||
|
||||
import static jadx.tests.api.utils.JadxMatchers.containsOne;
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertThat;
|
||||
|
||||
public class TestSwitchWithFallThroughCase extends IntegrationTest {
|
||||
|
||||
public static class TestCls {
|
||||
public String test(int a, boolean b, boolean c) {
|
||||
String str = "";
|
||||
switch (a % 4) {
|
||||
case 1:
|
||||
str += ">";
|
||||
if (a == 5 && b) {
|
||||
if (c) {
|
||||
str += "1";
|
||||
} else {
|
||||
str += "!c";
|
||||
}
|
||||
break;
|
||||
}
|
||||
case 2:
|
||||
if (b) {
|
||||
str += "2";
|
||||
}
|
||||
break;
|
||||
case 3:
|
||||
break;
|
||||
default:
|
||||
str += "default";
|
||||
break;
|
||||
}
|
||||
str += ";";
|
||||
return str;
|
||||
}
|
||||
|
||||
public void check() {
|
||||
assertEquals(">1;", test(5, true, true));
|
||||
assertEquals(">2;", test(1, true, true));
|
||||
assertEquals(";", test(3, true, true));
|
||||
assertEquals("default;", test(0, true, true));
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void test() {
|
||||
ClassNode cls = getClassNode(TestCls.class);
|
||||
String code = cls.getCode().toString();
|
||||
|
||||
assertThat(code, containsOne("switch (a % 4) {"));
|
||||
assertThat(code, containsOne("if (a == 5 && b) {"));
|
||||
assertThat(code, containsOne("if (b) {"));
|
||||
}
|
||||
}
|
||||
+67
@@ -0,0 +1,67 @@
|
||||
package jadx.tests.integration.switches;
|
||||
|
||||
import jadx.core.dex.nodes.ClassNode;
|
||||
import jadx.tests.api.IntegrationTest;
|
||||
|
||||
import org.junit.Test;
|
||||
|
||||
import static jadx.tests.api.utils.JadxMatchers.containsOne;
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertThat;
|
||||
|
||||
public class TestSwitchWithFallThroughCase2 extends IntegrationTest {
|
||||
|
||||
public static class TestCls {
|
||||
public String test(int a, boolean b, boolean c) {
|
||||
String str = "";
|
||||
if (a > 0) {
|
||||
switch (a % 4) {
|
||||
case 1:
|
||||
str += ">";
|
||||
if (a == 5 && b) {
|
||||
if (c) {
|
||||
str += "1";
|
||||
} else {
|
||||
str += "!c";
|
||||
}
|
||||
break;
|
||||
}
|
||||
case 2:
|
||||
if (b) {
|
||||
str += "2";
|
||||
}
|
||||
break;
|
||||
case 3:
|
||||
break;
|
||||
default:
|
||||
str += "default";
|
||||
break;
|
||||
}
|
||||
str += "+";
|
||||
}
|
||||
if (b && c) {
|
||||
str += "-";
|
||||
}
|
||||
return str;
|
||||
}
|
||||
|
||||
public void check() {
|
||||
assertEquals(">1+-", test(5, true, true));
|
||||
assertEquals(">2+-", test(1, true, true));
|
||||
assertEquals("+-", test(3, true, true));
|
||||
assertEquals("default+-", test(16, true, true));
|
||||
assertEquals("-", test(-1, true, true));
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void test() {
|
||||
setOutputCFG();
|
||||
ClassNode cls = getClassNode(TestCls.class);
|
||||
String code = cls.getCode().toString();
|
||||
|
||||
assertThat(code, containsOne("switch (a % 4) {"));
|
||||
assertThat(code, containsOne("if (a == 5 && b) {"));
|
||||
assertThat(code, containsOne("if (b) {"));
|
||||
}
|
||||
}
|
||||
@@ -62,7 +62,10 @@ public class TestSwitchWithTryCatch extends IntegrationTest {
|
||||
ClassNode cls = getClassNode(TestCls.class);
|
||||
String code = cls.getCode().toString();
|
||||
|
||||
assertThat(code, countString(3, "break;"));
|
||||
// assertThat(code, countString(3, "break;"));
|
||||
assertThat(code, countString(4, "return;"));
|
||||
|
||||
// TODO: remove redundant break
|
||||
assertThat(code, countString(4, "break;"));
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user