fix: optimize switch fallthrough (PR #2054)
* cache post dom map between switch cases * cache post dom map of whole methods * calculate full post dom tree, fix switch out block detection --------- Co-authored-by: Skylot <skylot@gmail.com>
This commit is contained in:
@@ -56,7 +56,7 @@ public class SimpleModeHelper {
|
||||
if (!block.contains(AFlag.EXC_BOTTOM_SPLITTER)) {
|
||||
startLabel.set(block.getId());
|
||||
}
|
||||
if (prev.getSuccessors().size() == 1 && !mth.isPreExitBlocks(prev)) {
|
||||
if (prev.getSuccessors().size() == 1 && !mth.isPreExitBlock(prev)) {
|
||||
endGoto.set(prev.getId());
|
||||
}
|
||||
}
|
||||
@@ -68,7 +68,7 @@ public class SimpleModeHelper {
|
||||
if (block.contains(AType.EXC_HANDLER)) {
|
||||
startLabel.set(block.getId());
|
||||
}
|
||||
if (nextBlock == null && !mth.isPreExitBlocks(block)) {
|
||||
if (nextBlock == null && !mth.isPreExitBlock(block)) {
|
||||
endGoto.set(block.getId());
|
||||
}
|
||||
prev = block;
|
||||
@@ -145,7 +145,7 @@ public class SimpleModeHelper {
|
||||
// DFS sort blocks to reduce goto count
|
||||
private List<BlockNode> getSortedBlocks() {
|
||||
List<BlockNode> list = new ArrayList<>(mth.getBasicBlocks().size());
|
||||
BlockUtils.dfsVisit(mth, list::add);
|
||||
BlockUtils.visitDFS(mth, list::add);
|
||||
return list;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -46,6 +46,11 @@ public final class BlockNode extends AttrNode implements IBlock, Comparable<Bloc
|
||||
*/
|
||||
private BitSet doms = EmptyBitSet.EMPTY;
|
||||
|
||||
/**
|
||||
* Post dominators, excluding self
|
||||
*/
|
||||
private BitSet postDoms = EmptyBitSet.EMPTY;
|
||||
|
||||
/**
|
||||
* Dominance frontier
|
||||
*/
|
||||
@@ -56,6 +61,11 @@ public final class BlockNode extends AttrNode implements IBlock, Comparable<Bloc
|
||||
*/
|
||||
private BlockNode idom;
|
||||
|
||||
/**
|
||||
* Immediate post dominator
|
||||
*/
|
||||
private BlockNode iPostDom;
|
||||
|
||||
/**
|
||||
* Blocks on which dominates this block
|
||||
*/
|
||||
@@ -165,6 +175,14 @@ public final class BlockNode extends AttrNode implements IBlock, Comparable<Bloc
|
||||
this.doms = doms;
|
||||
}
|
||||
|
||||
public BitSet getPostDoms() {
|
||||
return postDoms;
|
||||
}
|
||||
|
||||
public void setPostDoms(BitSet postDoms) {
|
||||
this.postDoms = postDoms;
|
||||
}
|
||||
|
||||
public BitSet getDomFrontier() {
|
||||
return domFrontier;
|
||||
}
|
||||
@@ -184,6 +202,14 @@ public final class BlockNode extends AttrNode implements IBlock, Comparable<Bloc
|
||||
this.idom = idom;
|
||||
}
|
||||
|
||||
public BlockNode getIPostDom() {
|
||||
return iPostDom;
|
||||
}
|
||||
|
||||
public void setIPostDom(BlockNode iPostDom) {
|
||||
this.iPostDom = iPostDom;
|
||||
}
|
||||
|
||||
public List<BlockNode> getDominatesOn() {
|
||||
return dominatesOn;
|
||||
}
|
||||
@@ -233,6 +259,6 @@ public final class BlockNode extends AttrNode implements IBlock, Comparable<Bloc
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "B:" + cid + ':' + InsnUtils.formatOffset(startOffset);
|
||||
return "B:" + id + ':' + InsnUtils.formatOffset(startOffset);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -360,10 +360,13 @@ public class MethodNode extends NotificationAttrNode implements IMethodDetails,
|
||||
|
||||
public void setBasicBlocks(List<BlockNode> blocks) {
|
||||
this.blocks = blocks;
|
||||
int i = 0;
|
||||
for (BlockNode block : blocks) {
|
||||
block.setId(i);
|
||||
i++;
|
||||
updateBlockIds(blocks);
|
||||
}
|
||||
|
||||
public void updateBlockIds(List<BlockNode> blocks) {
|
||||
int count = blocks.size();
|
||||
for (int i = 0; i < count; i++) {
|
||||
blocks.get(i).setId(i);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -391,7 +394,7 @@ public class MethodNode extends NotificationAttrNode implements IMethodDetails,
|
||||
return exitBlock.getPredecessors();
|
||||
}
|
||||
|
||||
public boolean isPreExitBlocks(BlockNode block) {
|
||||
public boolean isPreExitBlock(BlockNode block) {
|
||||
List<BlockNode> successors = block.getSuccessors();
|
||||
if (successors.size() == 1) {
|
||||
return successors.get(0).equals(exitBlock);
|
||||
|
||||
@@ -15,7 +15,12 @@ import jadx.core.utils.exceptions.CodegenException;
|
||||
|
||||
public final class SwitchRegion extends AbstractRegion implements IBranchRegion {
|
||||
|
||||
public static final Object DEFAULT_CASE_KEY = new Object();
|
||||
public static final Object DEFAULT_CASE_KEY = new Object() {
|
||||
@Override
|
||||
public String toString() {
|
||||
return "default";
|
||||
}
|
||||
};
|
||||
|
||||
private final BlockNode header;
|
||||
|
||||
|
||||
@@ -199,7 +199,7 @@ public class DotGraphVisitor extends AbstractVisitor {
|
||||
dot.add("color=red,");
|
||||
}
|
||||
dot.add("label=\"{");
|
||||
dot.add(String.valueOf(block.getCId())).add("\\:\\ ");
|
||||
dot.add(String.valueOf(block.getId())).add("\\:\\ ");
|
||||
dot.add(InsnUtils.formatOffset(block.getStartOffset()));
|
||||
if (!attrs.isEmpty()) {
|
||||
dot.add('|').add(attrs);
|
||||
@@ -208,6 +208,8 @@ public class DotGraphVisitor extends AbstractVisitor {
|
||||
dot.add('|');
|
||||
dot.startLine("doms: ").add(escape(block.getDoms()));
|
||||
dot.startLine("\\lidom: ").add(escape(block.getIDom()));
|
||||
dot.startLine("\\lpost-doms: ").add(escape(block.getPostDoms()));
|
||||
dot.startLine("\\lpost-idom: ").add(escape(block.getIPostDom()));
|
||||
dot.startLine("\\ldom-f: ").add(escape(block.getDomFrontier()));
|
||||
dot.startLine("\\ldoms-on: ").add(escape(Utils.listToString(block.getDominatesOn())));
|
||||
dot.startLine("\\l");
|
||||
@@ -230,10 +232,10 @@ public class DotGraphVisitor extends AbstractVisitor {
|
||||
|
||||
if (PRINT_DOMINATORS) {
|
||||
for (BlockNode c : block.getDominatesOn()) {
|
||||
conn.startLine(block.getCId() + " -> " + c.getCId() + "[color=green];");
|
||||
conn.startLine(block.getId() + " -> " + c.getId() + "[color=green];");
|
||||
}
|
||||
for (BlockNode dom : BlockUtils.bitSetToBlocks(mth, block.getDomFrontier())) {
|
||||
conn.startLine("f_" + block.getCId() + " -> f_" + dom.getCId() + "[color=blue];");
|
||||
conn.startLine("f_" + block.getId() + " -> f_" + dom.getId() + "[color=blue];");
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -273,7 +275,7 @@ public class DotGraphVisitor extends AbstractVisitor {
|
||||
private String makeName(IContainer c) {
|
||||
String name;
|
||||
if (c instanceof BlockNode) {
|
||||
name = "Node_" + ((BlockNode) c).getCId();
|
||||
name = "Node_" + ((BlockNode) c).getId();
|
||||
} else if (c instanceof IBlock) {
|
||||
name = "Node_" + c.getClass().getSimpleName() + '_' + c.hashCode();
|
||||
} else {
|
||||
|
||||
@@ -68,7 +68,7 @@ public class BlockExceptionHandler {
|
||||
BlockProcessor.removeMarkedBlocks(mth);
|
||||
|
||||
BlockSet sorted = new BlockSet(mth);
|
||||
BlockUtils.dfsVisit(mth, sorted::set);
|
||||
BlockUtils.visitDFS(mth, sorted::set);
|
||||
removeUnusedExcHandlers(mth, tryBlocks, sorted);
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -70,6 +70,8 @@ public class BlockProcessor extends AbstractVisitor {
|
||||
registerLoops(mth);
|
||||
processNestedLoops(mth);
|
||||
|
||||
PostDominatorTree.compute(mth);
|
||||
|
||||
updateCleanSuccessors(mth);
|
||||
if (!mth.contains(AFlag.DISABLE_BLOCKS_LOCK)) {
|
||||
mth.finishBasicBlocks();
|
||||
|
||||
@@ -3,8 +3,7 @@ package jadx.core.dex.visitors.blocks;
|
||||
import java.util.ArrayList;
|
||||
import java.util.BitSet;
|
||||
import java.util.List;
|
||||
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
import java.util.function.Function;
|
||||
|
||||
import jadx.core.dex.nodes.BlockNode;
|
||||
import jadx.core.dex.nodes.MethodNode;
|
||||
@@ -23,14 +22,14 @@ public class DominatorTree {
|
||||
|
||||
public static void compute(MethodNode mth) {
|
||||
List<BlockNode> sorted = sortBlocks(mth);
|
||||
BlockNode[] doms = build(sorted);
|
||||
BlockNode[] doms = build(sorted, BlockNode::getPredecessors);
|
||||
apply(sorted, doms);
|
||||
}
|
||||
|
||||
private static List<BlockNode> sortBlocks(MethodNode mth) {
|
||||
int blocksCount = mth.getBasicBlocks().size();
|
||||
List<BlockNode> sorted = new ArrayList<>(blocksCount);
|
||||
BlockUtils.dfsVisit(mth, sorted::add);
|
||||
BlockUtils.visitDFS(mth, sorted::add);
|
||||
if (sorted.size() != blocksCount) {
|
||||
throw new JadxRuntimeException("Found unreachable blocks");
|
||||
}
|
||||
@@ -38,8 +37,7 @@ public class DominatorTree {
|
||||
return sorted;
|
||||
}
|
||||
|
||||
@NotNull
|
||||
private static BlockNode[] build(List<BlockNode> sorted) {
|
||||
static BlockNode[] build(List<BlockNode> sorted, Function<BlockNode, List<BlockNode>> predFunc) {
|
||||
int blocksCount = sorted.size();
|
||||
BlockNode[] doms = new BlockNode[blocksCount];
|
||||
doms[0] = sorted.get(0);
|
||||
@@ -48,7 +46,7 @@ public class DominatorTree {
|
||||
changed = false;
|
||||
for (int blockId = 1; blockId < blocksCount; blockId++) {
|
||||
BlockNode b = sorted.get(blockId);
|
||||
List<BlockNode> preds = b.getPredecessors();
|
||||
List<BlockNode> preds = predFunc.apply(b);
|
||||
int pickedPred = -1;
|
||||
BlockNode newIDom = null;
|
||||
for (BlockNode pred : preds) {
|
||||
@@ -60,7 +58,7 @@ public class DominatorTree {
|
||||
}
|
||||
}
|
||||
if (newIDom == null) {
|
||||
throw new JadxRuntimeException("No predecessors for block: " + b);
|
||||
throw new JadxRuntimeException("No immediate dominator for block: " + b);
|
||||
}
|
||||
for (BlockNode predBlock : preds) {
|
||||
int predId = predBlock.getId();
|
||||
@@ -110,7 +108,7 @@ public class DominatorTree {
|
||||
}
|
||||
}
|
||||
|
||||
private static BitSet collectDoms(BlockNode[] doms, BlockNode idom) {
|
||||
static BitSet collectDoms(BlockNode[] doms, BlockNode idom) {
|
||||
BitSet domBS = new BitSet(doms.length);
|
||||
BlockNode nextIDom = idom;
|
||||
while (true) {
|
||||
|
||||
@@ -0,0 +1,65 @@
|
||||
package jadx.core.dex.visitors.blocks;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.BitSet;
|
||||
import java.util.List;
|
||||
|
||||
import jadx.core.dex.nodes.BlockNode;
|
||||
import jadx.core.dex.nodes.MethodNode;
|
||||
import jadx.core.utils.BlockUtils;
|
||||
import jadx.core.utils.EmptyBitSet;
|
||||
|
||||
public class PostDominatorTree {
|
||||
|
||||
public static void compute(MethodNode mth) {
|
||||
try {
|
||||
int mthBlocksCount = mth.getBasicBlocks().size();
|
||||
List<BlockNode> sorted = new ArrayList<>(mthBlocksCount);
|
||||
BlockUtils.visitReverseDFS(mth, sorted::add);
|
||||
// temporary set block ids to match reverse sorted order
|
||||
// save old ids for later remapping
|
||||
int blocksCount = sorted.size();
|
||||
int[] ids = new int[mthBlocksCount];
|
||||
for (int i = 0; i < blocksCount; i++) {
|
||||
ids[i] = sorted.get(i).getId();
|
||||
}
|
||||
mth.updateBlockIds(sorted);
|
||||
|
||||
BlockNode[] postDoms = DominatorTree.build(sorted, BlockNode::getSuccessors);
|
||||
BlockNode firstBlock = sorted.get(0);
|
||||
firstBlock.setPostDoms(EmptyBitSet.EMPTY);
|
||||
firstBlock.setIPostDom(null);
|
||||
for (int i = 1; i < blocksCount; i++) {
|
||||
BlockNode block = sorted.get(i);
|
||||
BlockNode iPostDom = postDoms[i];
|
||||
block.setIPostDom(iPostDom);
|
||||
BitSet postDomBS = DominatorTree.collectDoms(postDoms, iPostDom);
|
||||
block.setPostDoms(postDomBS);
|
||||
}
|
||||
for (int i = 1; i < blocksCount; i++) {
|
||||
BlockNode block = sorted.get(i);
|
||||
BitSet bs = new BitSet(blocksCount);
|
||||
block.getPostDoms().stream().forEach(n -> bs.set(ids[n]));
|
||||
bs.clear(ids[i]);
|
||||
block.setPostDoms(bs);
|
||||
}
|
||||
// check for missing blocks in 'sorted' list
|
||||
// can be caused by infinite loops
|
||||
int blocksDelta = mthBlocksCount - blocksCount;
|
||||
if (blocksDelta != 0) {
|
||||
int insnsCount = 0;
|
||||
for (BlockNode block : mth.getBasicBlocks()) {
|
||||
if (block.getPostDoms() == null) {
|
||||
block.setPostDoms(EmptyBitSet.EMPTY);
|
||||
block.setIPostDom(null);
|
||||
insnsCount += block.getInstructions().size();
|
||||
}
|
||||
}
|
||||
mth.addInfoComment("Infinite loop detected, blocks: " + blocksDelta + ", insns: " + insnsCount);
|
||||
}
|
||||
} finally {
|
||||
// revert block ids change
|
||||
mth.updateBlockIds(mth.getBasicBlocks());
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -2,7 +2,6 @@ package jadx.core.dex.visitors.regions;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.BitSet;
|
||||
import java.util.Collections;
|
||||
import java.util.HashSet;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.LinkedHashSet;
|
||||
@@ -14,8 +13,6 @@ import java.util.Optional;
|
||||
import java.util.Set;
|
||||
|
||||
import org.jetbrains.annotations.Nullable;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import jadx.core.dex.attributes.AFlag;
|
||||
import jadx.core.dex.attributes.AType;
|
||||
@@ -60,8 +57,6 @@ import static jadx.core.utils.BlockUtils.getNextBlock;
|
||||
import static jadx.core.utils.BlockUtils.isPathExists;
|
||||
|
||||
public class RegionMaker {
|
||||
private static final Logger LOG = LoggerFactory.getLogger(RegionMaker.class);
|
||||
|
||||
private final MethodNode mth;
|
||||
private final int regionsLimit;
|
||||
private final BitSet processedBlocks;
|
||||
@@ -794,53 +789,32 @@ public class RegionMaker {
|
||||
keys.add(SwitchRegion.DEFAULT_CASE_KEY);
|
||||
}
|
||||
|
||||
// search 'out' block - 'next' block after whole switch statement
|
||||
BlockNode out;
|
||||
LoopInfo loop = mth.getLoopForBlock(block);
|
||||
if (loop == null) {
|
||||
out = calcPostDomOut(mth, block, mth.getPreExitBlocks());
|
||||
} else {
|
||||
BlockNode loopEnd = loop.getEnd();
|
||||
stack.addExit(loop.getStart());
|
||||
if (stack.containsExit(block)
|
||||
|| block == loopEnd
|
||||
|| loopEnd.getPredecessors().contains(block)) {
|
||||
// in exits or last insn in loop => no 'out' block
|
||||
out = null;
|
||||
} else {
|
||||
// treat 'continue' as exit
|
||||
out = calcPostDomOut(mth, block, loopEnd.getPredecessors());
|
||||
if (out != null) {
|
||||
insertContinueInSwitch(block, out, loopEnd);
|
||||
} else {
|
||||
// no 'continue'
|
||||
out = calcPostDomOut(mth, block, Collections.singletonList(loopEnd));
|
||||
}
|
||||
}
|
||||
if (out == loop.getStart()) {
|
||||
// no other outs instead back edge to loop start
|
||||
out = null;
|
||||
}
|
||||
}
|
||||
if (out != null && processedBlocks.get(out.getId())) {
|
||||
// out block already processed, prevent endless loop
|
||||
throw new JadxRuntimeException("Failed to find switch 'out' block");
|
||||
}
|
||||
|
||||
SwitchRegion sw = new SwitchRegion(currentRegion, block);
|
||||
insn.addAttr(new RegionRefAttr(sw));
|
||||
currentRegion.getSubBlocks().add(sw);
|
||||
stack.push(sw);
|
||||
|
||||
BlockNode out = calcSwitchOut(block, stack);
|
||||
stack.addExit(out);
|
||||
|
||||
// detect fallthrough cases
|
||||
processFallThroughCases(sw, out, stack, blocksMap);
|
||||
removeEmptyCases(insn, sw, defCase);
|
||||
|
||||
stack.pop();
|
||||
return out;
|
||||
}
|
||||
|
||||
private void processFallThroughCases(SwitchRegion sw, @Nullable BlockNode out,
|
||||
RegionStack stack, Map<BlockNode, List<Object>> blocksMap) {
|
||||
Map<BlockNode, BlockNode> fallThroughCases = new LinkedHashMap<>();
|
||||
if (out != null) {
|
||||
// detect fallthrough cases
|
||||
BitSet caseBlocks = BlockUtils.blocksToBitSet(mth, blocksMap.keySet());
|
||||
caseBlocks.clear(out.getId());
|
||||
for (BlockNode successor : block.getCleanSuccessors()) {
|
||||
BlockNode fallThroughBlock = searchFallThroughCase(successor, out, caseBlocks);
|
||||
if (fallThroughBlock != null) {
|
||||
for (BlockNode successor : sw.getHeader().getCleanSuccessors()) {
|
||||
BitSet df = successor.getDomFrontier();
|
||||
if (df.intersects(caseBlocks)) {
|
||||
BlockNode fallThroughBlock = getOneIntersectionBlock(out, caseBlocks, df);
|
||||
fallThroughCases.put(successor, fallThroughBlock);
|
||||
}
|
||||
}
|
||||
@@ -874,26 +848,6 @@ public class RegionMaker {
|
||||
// 'break' instruction will be inserted in RegionMakerVisitor.PostRegionVisitor
|
||||
}
|
||||
}
|
||||
|
||||
removeEmptyCases(insn, sw, defCase);
|
||||
|
||||
stack.pop();
|
||||
return out;
|
||||
}
|
||||
|
||||
@Nullable
|
||||
private BlockNode searchFallThroughCase(BlockNode successor, BlockNode out, BitSet caseBlocks) {
|
||||
BitSet df = successor.getDomFrontier();
|
||||
if (df.intersects(caseBlocks)) {
|
||||
return getOneIntersectionBlock(out, caseBlocks, df);
|
||||
}
|
||||
Set<BlockNode> allPathsBlocks = BlockUtils.getAllPathsBlocks(successor, out);
|
||||
Map<BlockNode, BitSet> bitSetMap = BlockUtils.calcPartialPostDominance(mth, allPathsBlocks, out);
|
||||
BitSet pdoms = bitSetMap.get(successor);
|
||||
if (pdoms != null && pdoms.intersects(caseBlocks)) {
|
||||
return getOneIntersectionBlock(out, caseBlocks, pdoms);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
@Nullable
|
||||
@@ -904,35 +858,71 @@ public class RegionMaker {
|
||||
return BlockUtils.bitSetToOneBlock(mth, caseExits);
|
||||
}
|
||||
|
||||
@Nullable
|
||||
private static BlockNode calcPostDomOut(MethodNode mth, BlockNode block, List<BlockNode> exits) {
|
||||
if (exits.size() == 1 && mth.getExitBlock().equals(exits.get(0))) {
|
||||
// simple case: for only one exit which is equal to method exit block
|
||||
return BlockUtils.calcImmediatePostDominator(mth, block);
|
||||
}
|
||||
// fast search: union of blocks dominance frontier
|
||||
// work if no fallthrough cases and no returns inside switch
|
||||
BitSet outs = BlockUtils.copyBlocksBitSet(mth, block.getDomFrontier());
|
||||
private @Nullable BlockNode calcSwitchOut(BlockNode block, RegionStack stack) {
|
||||
// union of case blocks dominance frontier
|
||||
// works if no fallthrough cases and no returns inside switch
|
||||
BitSet outs = BlockUtils.newBlocksBitSet(mth);
|
||||
for (BlockNode s : block.getCleanSuccessors()) {
|
||||
outs.or(s.getDomFrontier());
|
||||
}
|
||||
outs.clear(block.getId());
|
||||
if (outs.isEmpty()) {
|
||||
// switch already contains method exit
|
||||
// add everything, out block not needed
|
||||
return mth.getExitBlock();
|
||||
}
|
||||
|
||||
if (outs.cardinality() != 1) {
|
||||
// slow search: calculate partial post-dominance for every exit node
|
||||
BitSet ipdoms = BlockUtils.newBlocksBitSet(mth);
|
||||
for (BlockNode exitBlock : exits) {
|
||||
if (BlockUtils.isAnyPathExists(block, exitBlock)) {
|
||||
Set<BlockNode> pathBlocks = BlockUtils.getAllPathsBlocks(block, exitBlock);
|
||||
BlockNode ipdom = BlockUtils.calcPartialImmediatePostDominator(mth, block, pathBlocks, exitBlock);
|
||||
if (ipdom != null) {
|
||||
ipdoms.set(ipdom.getId());
|
||||
BlockNode out;
|
||||
if (outs.cardinality() == 1) {
|
||||
// single exit
|
||||
out = BlockUtils.bitSetToOneBlock(mth, outs);
|
||||
} else {
|
||||
// several switch exits
|
||||
// possible 'return', 'continue' or fallthrough in one of the cases
|
||||
LoopInfo loop = mth.getLoopForBlock(block);
|
||||
if (loop != null) {
|
||||
outs.andNot(block.getPostDoms());
|
||||
out = BlockUtils.bitSetToOneBlock(mth, outs);
|
||||
if (out != null) {
|
||||
insertContinueInSwitch(block, out, loop.getEnd());
|
||||
if (out == loop.getStart()) {
|
||||
// no other outs instead back edge to loop start
|
||||
return null;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
outs.clear(mth.getExitBlock().getId());
|
||||
BlockNode imPostDom = block.getIPostDom();
|
||||
if (outs.get(imPostDom.getId())) {
|
||||
return imPostDom;
|
||||
}
|
||||
outs.andNot(block.getPostDoms());
|
||||
out = BlockUtils.bitSetToOneBlock(mth, outs);
|
||||
}
|
||||
outs.and(ipdoms);
|
||||
}
|
||||
return BlockUtils.bitSetToOneBlock(mth, outs);
|
||||
if (out != null && mth.isPreExitBlock(out)) {
|
||||
// include 'return' or 'throw' in case blocks
|
||||
out = mth.getExitBlock();
|
||||
}
|
||||
BlockNode imPostDom = block.getIPostDom();
|
||||
if (out != imPostDom && !mth.isPreExitBlock(imPostDom)) {
|
||||
// stop other paths at common exit
|
||||
stack.addExit(imPostDom);
|
||||
}
|
||||
if (block.getCleanSuccessors().contains(imPostDom)) {
|
||||
// add exit to stop on empty 'default' block
|
||||
stack.addExit(imPostDom);
|
||||
}
|
||||
if (out == null) {
|
||||
mth.addWarnComment("Failed to find 'out' block for switch in " + block + ". Please report as an issue.");
|
||||
// fallback option; should work in most cases
|
||||
out = block.getIPostDom();
|
||||
}
|
||||
if (out != null && processedBlocks.get(out.getId())) {
|
||||
// 'out' block already processed, prevent endless loop
|
||||
throw new JadxRuntimeException("Failed to find switch 'out' block (already processed)");
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -999,18 +989,21 @@ public class RegionMaker {
|
||||
return newBlocksMap;
|
||||
}
|
||||
|
||||
private void insertContinueInSwitch(BlockNode block, BlockNode out, BlockNode end) {
|
||||
int endId = end.getId();
|
||||
for (BlockNode s : block.getCleanSuccessors()) {
|
||||
if (s.getDomFrontier().get(endId) && s != out) {
|
||||
private void insertContinueInSwitch(BlockNode switchBlock, BlockNode switchOut, BlockNode loopEnd) {
|
||||
for (BlockNode caseBlock : switchBlock.getCleanSuccessors()) {
|
||||
if (caseBlock.getDomFrontier().get(loopEnd.getId()) && caseBlock != switchOut) {
|
||||
// search predecessor of loop end on path from this successor
|
||||
List<BlockNode> list = BlockUtils.collectBlocksDominatedBy(mth, s, s);
|
||||
for (BlockNode p : end.getPredecessors()) {
|
||||
if (list.contains(p)) {
|
||||
if (p.isSynthetic()) {
|
||||
p.getInstructions().add(new InsnNode(InsnType.CONTINUE, 0));
|
||||
Set<BlockNode> list = new HashSet<>(BlockUtils.collectBlocksDominatedBy(mth, caseBlock, caseBlock));
|
||||
if (list.contains(switchOut) || switchOut.getPredecessors().stream().anyMatch(list::contains)) {
|
||||
// 'continue' not needed
|
||||
} else {
|
||||
for (BlockNode p : loopEnd.getPredecessors()) {
|
||||
if (list.contains(p)) {
|
||||
if (p.isSynthetic()) {
|
||||
p.getInstructions().add(new InsnNode(InsnType.CONTINUE, 0));
|
||||
}
|
||||
break;
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,14 +7,13 @@ import java.util.BitSet;
|
||||
import java.util.Collection;
|
||||
import java.util.Collections;
|
||||
import java.util.Deque;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.LinkedHashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Queue;
|
||||
import java.util.Set;
|
||||
import java.util.function.Consumer;
|
||||
import java.util.function.Function;
|
||||
import java.util.function.Predicate;
|
||||
|
||||
import org.jetbrains.annotations.Nullable;
|
||||
@@ -37,6 +36,7 @@ import jadx.core.dex.nodes.MethodNode;
|
||||
import jadx.core.dex.regions.conditions.IfCondition;
|
||||
import jadx.core.dex.trycatch.CatchAttr;
|
||||
import jadx.core.dex.trycatch.ExceptionHandler;
|
||||
import jadx.core.utils.blocks.BlockSet;
|
||||
import jadx.core.utils.exceptions.JadxRuntimeException;
|
||||
|
||||
public class BlockUtils {
|
||||
@@ -483,33 +483,37 @@ public class BlockUtils {
|
||||
|
||||
public static List<BlockNode> collectAllSuccessors(MethodNode mth, BlockNode startBlock, boolean clean) {
|
||||
List<BlockNode> list = new ArrayList<>(mth.getBasicBlocks().size());
|
||||
dfsVisit(mth, startBlock, clean, list::add);
|
||||
Function<BlockNode, List<BlockNode>> nextFunc = clean ? BlockNode::getCleanSuccessors : BlockNode::getSuccessors;
|
||||
visitDFS(mth, startBlock, nextFunc, list::add);
|
||||
return list;
|
||||
}
|
||||
|
||||
public static void dfsVisit(MethodNode mth, Consumer<BlockNode> visitor) {
|
||||
dfsVisit(mth, mth.getEnterBlock(), false, visitor);
|
||||
public static void visitDFS(MethodNode mth, Consumer<BlockNode> visitor) {
|
||||
visitDFS(mth, mth.getEnterBlock(), BlockNode::getSuccessors, visitor);
|
||||
}
|
||||
|
||||
private static void dfsVisit(MethodNode mth, BlockNode startBlock, boolean clean, Consumer<BlockNode> visitor) {
|
||||
BitSet visited = newBlocksBitSet(mth);
|
||||
public static void visitReverseDFS(MethodNode mth, Consumer<BlockNode> visitor) {
|
||||
visitDFS(mth, mth.getExitBlock(), BlockNode::getPredecessors, visitor);
|
||||
}
|
||||
|
||||
private static void visitDFS(MethodNode mth, BlockNode startBlock,
|
||||
Function<BlockNode, List<BlockNode>> nextFunc, Consumer<BlockNode> visitor) {
|
||||
BlockSet visited = new BlockSet(mth);
|
||||
Deque<BlockNode> queue = new ArrayDeque<>();
|
||||
queue.addLast(startBlock);
|
||||
visited.set(startBlock.getId());
|
||||
visited.set(startBlock);
|
||||
while (true) {
|
||||
BlockNode current = queue.pollLast();
|
||||
if (current == null) {
|
||||
return;
|
||||
}
|
||||
visitor.accept(current);
|
||||
List<BlockNode> successors = clean ? current.getCleanSuccessors() : current.getSuccessors();
|
||||
int count = successors.size();
|
||||
List<BlockNode> nextBlocks = nextFunc.apply(current);
|
||||
int count = nextBlocks.size();
|
||||
for (int i = count - 1; i >= 0; i--) { // to preserve order in queue
|
||||
BlockNode next = successors.get(i);
|
||||
int nextId = next.getId();
|
||||
if (!visited.get(nextId)) {
|
||||
BlockNode next = nextBlocks.get(i);
|
||||
if (!visited.checkAndSet(next)) {
|
||||
queue.addLast(next);
|
||||
visited.set(nextId);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1156,104 +1160,6 @@ public class BlockUtils {
|
||||
return false;
|
||||
}
|
||||
|
||||
public static Map<BlockNode, BitSet> calcPostDominance(MethodNode mth) {
|
||||
return calcPartialPostDominance(mth, mth.getBasicBlocks(), mth.getPreExitBlocks().get(0));
|
||||
}
|
||||
|
||||
public static Map<BlockNode, BitSet> calcPartialPostDominance(MethodNode mth, Collection<BlockNode> blockNodes, BlockNode exitBlock) {
|
||||
int blocksCount = mth.getBasicBlocks().size();
|
||||
Map<BlockNode, BitSet> map = new HashMap<>(blocksCount);
|
||||
|
||||
BitSet initSet = new BitSet(blocksCount);
|
||||
for (BlockNode block : blockNodes) {
|
||||
initSet.set(block.getId());
|
||||
}
|
||||
|
||||
for (BlockNode block : blockNodes) {
|
||||
BitSet postDoms = new BitSet(blocksCount);
|
||||
postDoms.or(initSet);
|
||||
map.put(block, postDoms);
|
||||
}
|
||||
BitSet exitBitSet = map.get(exitBlock);
|
||||
exitBitSet.clear();
|
||||
exitBitSet.set(exitBlock.getId());
|
||||
|
||||
BitSet domSet = new BitSet(blocksCount);
|
||||
boolean changed;
|
||||
do {
|
||||
changed = false;
|
||||
for (BlockNode block : blockNodes) {
|
||||
if (block == exitBlock) {
|
||||
continue;
|
||||
}
|
||||
BitSet d = map.get(block);
|
||||
if (!changed) {
|
||||
domSet.clear();
|
||||
domSet.or(d);
|
||||
}
|
||||
for (BlockNode scc : block.getSuccessors()) {
|
||||
BitSet scPDoms = map.get(scc);
|
||||
if (scPDoms != null) {
|
||||
d.and(scPDoms);
|
||||
}
|
||||
}
|
||||
d.set(block.getId());
|
||||
if (!changed && !d.equals(domSet)) {
|
||||
changed = true;
|
||||
map.put(block, d);
|
||||
}
|
||||
}
|
||||
} while (changed);
|
||||
|
||||
blockNodes.forEach(block -> {
|
||||
BitSet postDoms = map.get(block);
|
||||
postDoms.clear(block.getId());
|
||||
if (postDoms.isEmpty()) {
|
||||
map.put(block, EmptyBitSet.EMPTY);
|
||||
}
|
||||
});
|
||||
return map;
|
||||
}
|
||||
|
||||
@Nullable
|
||||
public static BlockNode calcImmediatePostDominator(MethodNode mth, BlockNode block) {
|
||||
BlockNode oneSuccessor = Utils.getOne(block.getSuccessors());
|
||||
if (oneSuccessor != null) {
|
||||
return oneSuccessor;
|
||||
}
|
||||
return calcImmediatePostDominator(mth, block, calcPostDominance(mth));
|
||||
}
|
||||
|
||||
@Nullable
|
||||
public static BlockNode calcPartialImmediatePostDominator(MethodNode mth, BlockNode block,
|
||||
Collection<BlockNode> blockNodes, BlockNode exitBlock) {
|
||||
BlockNode oneSuccessor = Utils.getOne(block.getSuccessors());
|
||||
if (oneSuccessor != null) {
|
||||
return oneSuccessor;
|
||||
}
|
||||
Map<BlockNode, BitSet> pDomsMap = calcPartialPostDominance(mth, blockNodes, exitBlock);
|
||||
return calcImmediatePostDominator(mth, block, pDomsMap);
|
||||
}
|
||||
|
||||
@Nullable
|
||||
public static BlockNode calcImmediatePostDominator(MethodNode mth, BlockNode block, Map<BlockNode, BitSet> postDomsMap) {
|
||||
BlockNode oneSuccessor = Utils.getOne(block.getSuccessors());
|
||||
if (oneSuccessor != null) {
|
||||
return oneSuccessor;
|
||||
}
|
||||
List<BlockNode> basicBlocks = mth.getBasicBlocks();
|
||||
BitSet postDoms = postDomsMap.get(block);
|
||||
BitSet bs = copyBlocksBitSet(mth, postDoms);
|
||||
for (int i = bs.nextSetBit(0); i >= 0; i = bs.nextSetBit(i + 1)) {
|
||||
BlockNode pdomBlock = basicBlocks.get(i);
|
||||
BitSet pdoms = postDomsMap.get(pdomBlock);
|
||||
if (pdoms != null) {
|
||||
bs.andNot(pdoms);
|
||||
}
|
||||
}
|
||||
return bitSetToOneBlock(mth, bs);
|
||||
}
|
||||
|
||||
public static BlockNode getTopSplitterForHandler(BlockNode handlerBlock) {
|
||||
BlockNode block = getBlockWithFlag(handlerBlock.getPredecessors(), AFlag.EXC_TOP_SPLITTER);
|
||||
if (block == null) {
|
||||
|
||||
@@ -5,6 +5,7 @@ import org.junit.jupiter.api.Test;
|
||||
import jadx.tests.api.IntegrationTest;
|
||||
|
||||
import static jadx.tests.api.utils.assertj.JadxAssertions.assertThat;
|
||||
import static org.assertj.core.api.Assertions.catchThrowable;
|
||||
|
||||
public class TestSwitchWithThrow extends IntegrationTest {
|
||||
|
||||
@@ -19,19 +20,19 @@ public class TestSwitchWithThrow extends IntegrationTest {
|
||||
default:
|
||||
throw new IllegalStateException("Other");
|
||||
}
|
||||
} else {
|
||||
System.out.println("0");
|
||||
return -1;
|
||||
}
|
||||
System.out.println("0");
|
||||
return -1;
|
||||
}
|
||||
|
||||
public void check() {
|
||||
assertThat(test(0)).isEqualTo(-1);
|
||||
// TODO: implement 'invoke-custom' support
|
||||
// assertThat(catchThrowable(() -> test(1)))
|
||||
// .isInstanceOf(IllegalStateException.class).hasMessageContaining("1");
|
||||
// assertThat(catchThrowable(() -> test(3)))
|
||||
// .isInstanceOf(IllegalStateException.class).hasMessageContaining("Other");
|
||||
assertThat(catchThrowable(() -> test(1)))
|
||||
.isInstanceOf(IllegalStateException.class)
|
||||
.hasMessageContaining("1");
|
||||
assertThat(catchThrowable(() -> test(3)))
|
||||
.isInstanceOf(IllegalStateException.class)
|
||||
.hasMessageContaining("Other");
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user