fix: optimize switch fallthrough (PR #2054)

* cache post dom map between switch cases
* cache post dom map of whole methods
* calculate full post dom tree, fix switch out block detection

---------

Co-authored-by: Skylot <skylot@gmail.com>
This commit is contained in:
DanielFi
2024-02-14 20:31:38 +02:00
committed by GitHub
parent 0143423dc9
commit 0c33d723c8
12 changed files with 237 additions and 236 deletions
@@ -56,7 +56,7 @@ public class SimpleModeHelper {
if (!block.contains(AFlag.EXC_BOTTOM_SPLITTER)) {
startLabel.set(block.getId());
}
if (prev.getSuccessors().size() == 1 && !mth.isPreExitBlocks(prev)) {
if (prev.getSuccessors().size() == 1 && !mth.isPreExitBlock(prev)) {
endGoto.set(prev.getId());
}
}
@@ -68,7 +68,7 @@ public class SimpleModeHelper {
if (block.contains(AType.EXC_HANDLER)) {
startLabel.set(block.getId());
}
if (nextBlock == null && !mth.isPreExitBlocks(block)) {
if (nextBlock == null && !mth.isPreExitBlock(block)) {
endGoto.set(block.getId());
}
prev = block;
@@ -145,7 +145,7 @@ public class SimpleModeHelper {
// DFS sort blocks to reduce goto count
private List<BlockNode> getSortedBlocks() {
List<BlockNode> list = new ArrayList<>(mth.getBasicBlocks().size());
BlockUtils.dfsVisit(mth, list::add);
BlockUtils.visitDFS(mth, list::add);
return list;
}
}
@@ -46,6 +46,11 @@ public final class BlockNode extends AttrNode implements IBlock, Comparable<Bloc
*/
private BitSet doms = EmptyBitSet.EMPTY;
/**
* Post dominators, excluding self
*/
private BitSet postDoms = EmptyBitSet.EMPTY;
/**
* Dominance frontier
*/
@@ -56,6 +61,11 @@ public final class BlockNode extends AttrNode implements IBlock, Comparable<Bloc
*/
private BlockNode idom;
/**
* Immediate post dominator
*/
private BlockNode iPostDom;
/**
* Blocks on which dominates this block
*/
@@ -165,6 +175,14 @@ public final class BlockNode extends AttrNode implements IBlock, Comparable<Bloc
this.doms = doms;
}
public BitSet getPostDoms() {
return postDoms;
}
public void setPostDoms(BitSet postDoms) {
this.postDoms = postDoms;
}
public BitSet getDomFrontier() {
return domFrontier;
}
@@ -184,6 +202,14 @@ public final class BlockNode extends AttrNode implements IBlock, Comparable<Bloc
this.idom = idom;
}
public BlockNode getIPostDom() {
return iPostDom;
}
public void setIPostDom(BlockNode iPostDom) {
this.iPostDom = iPostDom;
}
public List<BlockNode> getDominatesOn() {
return dominatesOn;
}
@@ -233,6 +259,6 @@ public final class BlockNode extends AttrNode implements IBlock, Comparable<Bloc
@Override
public String toString() {
return "B:" + cid + ':' + InsnUtils.formatOffset(startOffset);
return "B:" + id + ':' + InsnUtils.formatOffset(startOffset);
}
}
@@ -360,10 +360,13 @@ public class MethodNode extends NotificationAttrNode implements IMethodDetails,
public void setBasicBlocks(List<BlockNode> blocks) {
this.blocks = blocks;
int i = 0;
for (BlockNode block : blocks) {
block.setId(i);
i++;
updateBlockIds(blocks);
}
public void updateBlockIds(List<BlockNode> blocks) {
int count = blocks.size();
for (int i = 0; i < count; i++) {
blocks.get(i).setId(i);
}
}
@@ -391,7 +394,7 @@ public class MethodNode extends NotificationAttrNode implements IMethodDetails,
return exitBlock.getPredecessors();
}
public boolean isPreExitBlocks(BlockNode block) {
public boolean isPreExitBlock(BlockNode block) {
List<BlockNode> successors = block.getSuccessors();
if (successors.size() == 1) {
return successors.get(0).equals(exitBlock);
@@ -15,7 +15,12 @@ import jadx.core.utils.exceptions.CodegenException;
public final class SwitchRegion extends AbstractRegion implements IBranchRegion {
public static final Object DEFAULT_CASE_KEY = new Object();
public static final Object DEFAULT_CASE_KEY = new Object() {
@Override
public String toString() {
return "default";
}
};
private final BlockNode header;
@@ -199,7 +199,7 @@ public class DotGraphVisitor extends AbstractVisitor {
dot.add("color=red,");
}
dot.add("label=\"{");
dot.add(String.valueOf(block.getCId())).add("\\:\\ ");
dot.add(String.valueOf(block.getId())).add("\\:\\ ");
dot.add(InsnUtils.formatOffset(block.getStartOffset()));
if (!attrs.isEmpty()) {
dot.add('|').add(attrs);
@@ -208,6 +208,8 @@ public class DotGraphVisitor extends AbstractVisitor {
dot.add('|');
dot.startLine("doms: ").add(escape(block.getDoms()));
dot.startLine("\\lidom: ").add(escape(block.getIDom()));
dot.startLine("\\lpost-doms: ").add(escape(block.getPostDoms()));
dot.startLine("\\lpost-idom: ").add(escape(block.getIPostDom()));
dot.startLine("\\ldom-f: ").add(escape(block.getDomFrontier()));
dot.startLine("\\ldoms-on: ").add(escape(Utils.listToString(block.getDominatesOn())));
dot.startLine("\\l");
@@ -230,10 +232,10 @@ public class DotGraphVisitor extends AbstractVisitor {
if (PRINT_DOMINATORS) {
for (BlockNode c : block.getDominatesOn()) {
conn.startLine(block.getCId() + " -> " + c.getCId() + "[color=green];");
conn.startLine(block.getId() + " -> " + c.getId() + "[color=green];");
}
for (BlockNode dom : BlockUtils.bitSetToBlocks(mth, block.getDomFrontier())) {
conn.startLine("f_" + block.getCId() + " -> f_" + dom.getCId() + "[color=blue];");
conn.startLine("f_" + block.getId() + " -> f_" + dom.getId() + "[color=blue];");
}
}
}
@@ -273,7 +275,7 @@ public class DotGraphVisitor extends AbstractVisitor {
private String makeName(IContainer c) {
String name;
if (c instanceof BlockNode) {
name = "Node_" + ((BlockNode) c).getCId();
name = "Node_" + ((BlockNode) c).getId();
} else if (c instanceof IBlock) {
name = "Node_" + c.getClass().getSimpleName() + '_' + c.hashCode();
} else {
@@ -68,7 +68,7 @@ public class BlockExceptionHandler {
BlockProcessor.removeMarkedBlocks(mth);
BlockSet sorted = new BlockSet(mth);
BlockUtils.dfsVisit(mth, sorted::set);
BlockUtils.visitDFS(mth, sorted::set);
removeUnusedExcHandlers(mth, tryBlocks, sorted);
return true;
}
@@ -70,6 +70,8 @@ public class BlockProcessor extends AbstractVisitor {
registerLoops(mth);
processNestedLoops(mth);
PostDominatorTree.compute(mth);
updateCleanSuccessors(mth);
if (!mth.contains(AFlag.DISABLE_BLOCKS_LOCK)) {
mth.finishBasicBlocks();
@@ -3,8 +3,7 @@ package jadx.core.dex.visitors.blocks;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.List;
import org.jetbrains.annotations.NotNull;
import java.util.function.Function;
import jadx.core.dex.nodes.BlockNode;
import jadx.core.dex.nodes.MethodNode;
@@ -23,14 +22,14 @@ public class DominatorTree {
public static void compute(MethodNode mth) {
List<BlockNode> sorted = sortBlocks(mth);
BlockNode[] doms = build(sorted);
BlockNode[] doms = build(sorted, BlockNode::getPredecessors);
apply(sorted, doms);
}
private static List<BlockNode> sortBlocks(MethodNode mth) {
int blocksCount = mth.getBasicBlocks().size();
List<BlockNode> sorted = new ArrayList<>(blocksCount);
BlockUtils.dfsVisit(mth, sorted::add);
BlockUtils.visitDFS(mth, sorted::add);
if (sorted.size() != blocksCount) {
throw new JadxRuntimeException("Found unreachable blocks");
}
@@ -38,8 +37,7 @@ public class DominatorTree {
return sorted;
}
@NotNull
private static BlockNode[] build(List<BlockNode> sorted) {
static BlockNode[] build(List<BlockNode> sorted, Function<BlockNode, List<BlockNode>> predFunc) {
int blocksCount = sorted.size();
BlockNode[] doms = new BlockNode[blocksCount];
doms[0] = sorted.get(0);
@@ -48,7 +46,7 @@ public class DominatorTree {
changed = false;
for (int blockId = 1; blockId < blocksCount; blockId++) {
BlockNode b = sorted.get(blockId);
List<BlockNode> preds = b.getPredecessors();
List<BlockNode> preds = predFunc.apply(b);
int pickedPred = -1;
BlockNode newIDom = null;
for (BlockNode pred : preds) {
@@ -60,7 +58,7 @@ public class DominatorTree {
}
}
if (newIDom == null) {
throw new JadxRuntimeException("No predecessors for block: " + b);
throw new JadxRuntimeException("No immediate dominator for block: " + b);
}
for (BlockNode predBlock : preds) {
int predId = predBlock.getId();
@@ -110,7 +108,7 @@ public class DominatorTree {
}
}
private static BitSet collectDoms(BlockNode[] doms, BlockNode idom) {
static BitSet collectDoms(BlockNode[] doms, BlockNode idom) {
BitSet domBS = new BitSet(doms.length);
BlockNode nextIDom = idom;
while (true) {
@@ -0,0 +1,65 @@
package jadx.core.dex.visitors.blocks;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.List;
import jadx.core.dex.nodes.BlockNode;
import jadx.core.dex.nodes.MethodNode;
import jadx.core.utils.BlockUtils;
import jadx.core.utils.EmptyBitSet;
public class PostDominatorTree {
public static void compute(MethodNode mth) {
try {
int mthBlocksCount = mth.getBasicBlocks().size();
List<BlockNode> sorted = new ArrayList<>(mthBlocksCount);
BlockUtils.visitReverseDFS(mth, sorted::add);
// temporary set block ids to match reverse sorted order
// save old ids for later remapping
int blocksCount = sorted.size();
int[] ids = new int[mthBlocksCount];
for (int i = 0; i < blocksCount; i++) {
ids[i] = sorted.get(i).getId();
}
mth.updateBlockIds(sorted);
BlockNode[] postDoms = DominatorTree.build(sorted, BlockNode::getSuccessors);
BlockNode firstBlock = sorted.get(0);
firstBlock.setPostDoms(EmptyBitSet.EMPTY);
firstBlock.setIPostDom(null);
for (int i = 1; i < blocksCount; i++) {
BlockNode block = sorted.get(i);
BlockNode iPostDom = postDoms[i];
block.setIPostDom(iPostDom);
BitSet postDomBS = DominatorTree.collectDoms(postDoms, iPostDom);
block.setPostDoms(postDomBS);
}
for (int i = 1; i < blocksCount; i++) {
BlockNode block = sorted.get(i);
BitSet bs = new BitSet(blocksCount);
block.getPostDoms().stream().forEach(n -> bs.set(ids[n]));
bs.clear(ids[i]);
block.setPostDoms(bs);
}
// check for missing blocks in 'sorted' list
// can be caused by infinite loops
int blocksDelta = mthBlocksCount - blocksCount;
if (blocksDelta != 0) {
int insnsCount = 0;
for (BlockNode block : mth.getBasicBlocks()) {
if (block.getPostDoms() == null) {
block.setPostDoms(EmptyBitSet.EMPTY);
block.setIPostDom(null);
insnsCount += block.getInstructions().size();
}
}
mth.addInfoComment("Infinite loop detected, blocks: " + blocksDelta + ", insns: " + insnsCount);
}
} finally {
// revert block ids change
mth.updateBlockIds(mth.getBasicBlocks());
}
}
}
@@ -2,7 +2,6 @@ package jadx.core.dex.visitors.regions;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
@@ -14,8 +13,6 @@ import java.util.Optional;
import java.util.Set;
import org.jetbrains.annotations.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import jadx.core.dex.attributes.AFlag;
import jadx.core.dex.attributes.AType;
@@ -60,8 +57,6 @@ import static jadx.core.utils.BlockUtils.getNextBlock;
import static jadx.core.utils.BlockUtils.isPathExists;
public class RegionMaker {
private static final Logger LOG = LoggerFactory.getLogger(RegionMaker.class);
private final MethodNode mth;
private final int regionsLimit;
private final BitSet processedBlocks;
@@ -794,53 +789,32 @@ public class RegionMaker {
keys.add(SwitchRegion.DEFAULT_CASE_KEY);
}
// search 'out' block - 'next' block after whole switch statement
BlockNode out;
LoopInfo loop = mth.getLoopForBlock(block);
if (loop == null) {
out = calcPostDomOut(mth, block, mth.getPreExitBlocks());
} else {
BlockNode loopEnd = loop.getEnd();
stack.addExit(loop.getStart());
if (stack.containsExit(block)
|| block == loopEnd
|| loopEnd.getPredecessors().contains(block)) {
// in exits or last insn in loop => no 'out' block
out = null;
} else {
// treat 'continue' as exit
out = calcPostDomOut(mth, block, loopEnd.getPredecessors());
if (out != null) {
insertContinueInSwitch(block, out, loopEnd);
} else {
// no 'continue'
out = calcPostDomOut(mth, block, Collections.singletonList(loopEnd));
}
}
if (out == loop.getStart()) {
// no other outs instead back edge to loop start
out = null;
}
}
if (out != null && processedBlocks.get(out.getId())) {
// out block already processed, prevent endless loop
throw new JadxRuntimeException("Failed to find switch 'out' block");
}
SwitchRegion sw = new SwitchRegion(currentRegion, block);
insn.addAttr(new RegionRefAttr(sw));
currentRegion.getSubBlocks().add(sw);
stack.push(sw);
BlockNode out = calcSwitchOut(block, stack);
stack.addExit(out);
// detect fallthrough cases
processFallThroughCases(sw, out, stack, blocksMap);
removeEmptyCases(insn, sw, defCase);
stack.pop();
return out;
}
private void processFallThroughCases(SwitchRegion sw, @Nullable BlockNode out,
RegionStack stack, Map<BlockNode, List<Object>> blocksMap) {
Map<BlockNode, BlockNode> fallThroughCases = new LinkedHashMap<>();
if (out != null) {
// detect fallthrough cases
BitSet caseBlocks = BlockUtils.blocksToBitSet(mth, blocksMap.keySet());
caseBlocks.clear(out.getId());
for (BlockNode successor : block.getCleanSuccessors()) {
BlockNode fallThroughBlock = searchFallThroughCase(successor, out, caseBlocks);
if (fallThroughBlock != null) {
for (BlockNode successor : sw.getHeader().getCleanSuccessors()) {
BitSet df = successor.getDomFrontier();
if (df.intersects(caseBlocks)) {
BlockNode fallThroughBlock = getOneIntersectionBlock(out, caseBlocks, df);
fallThroughCases.put(successor, fallThroughBlock);
}
}
@@ -874,26 +848,6 @@ public class RegionMaker {
// 'break' instruction will be inserted in RegionMakerVisitor.PostRegionVisitor
}
}
removeEmptyCases(insn, sw, defCase);
stack.pop();
return out;
}
@Nullable
private BlockNode searchFallThroughCase(BlockNode successor, BlockNode out, BitSet caseBlocks) {
BitSet df = successor.getDomFrontier();
if (df.intersects(caseBlocks)) {
return getOneIntersectionBlock(out, caseBlocks, df);
}
Set<BlockNode> allPathsBlocks = BlockUtils.getAllPathsBlocks(successor, out);
Map<BlockNode, BitSet> bitSetMap = BlockUtils.calcPartialPostDominance(mth, allPathsBlocks, out);
BitSet pdoms = bitSetMap.get(successor);
if (pdoms != null && pdoms.intersects(caseBlocks)) {
return getOneIntersectionBlock(out, caseBlocks, pdoms);
}
return null;
}
@Nullable
@@ -904,35 +858,71 @@ public class RegionMaker {
return BlockUtils.bitSetToOneBlock(mth, caseExits);
}
@Nullable
private static BlockNode calcPostDomOut(MethodNode mth, BlockNode block, List<BlockNode> exits) {
if (exits.size() == 1 && mth.getExitBlock().equals(exits.get(0))) {
// simple case: for only one exit which is equal to method exit block
return BlockUtils.calcImmediatePostDominator(mth, block);
}
// fast search: union of blocks dominance frontier
// work if no fallthrough cases and no returns inside switch
BitSet outs = BlockUtils.copyBlocksBitSet(mth, block.getDomFrontier());
private @Nullable BlockNode calcSwitchOut(BlockNode block, RegionStack stack) {
// union of case blocks dominance frontier
// works if no fallthrough cases and no returns inside switch
BitSet outs = BlockUtils.newBlocksBitSet(mth);
for (BlockNode s : block.getCleanSuccessors()) {
outs.or(s.getDomFrontier());
}
outs.clear(block.getId());
if (outs.isEmpty()) {
// switch already contains method exit
// add everything, out block not needed
return mth.getExitBlock();
}
if (outs.cardinality() != 1) {
// slow search: calculate partial post-dominance for every exit node
BitSet ipdoms = BlockUtils.newBlocksBitSet(mth);
for (BlockNode exitBlock : exits) {
if (BlockUtils.isAnyPathExists(block, exitBlock)) {
Set<BlockNode> pathBlocks = BlockUtils.getAllPathsBlocks(block, exitBlock);
BlockNode ipdom = BlockUtils.calcPartialImmediatePostDominator(mth, block, pathBlocks, exitBlock);
if (ipdom != null) {
ipdoms.set(ipdom.getId());
BlockNode out;
if (outs.cardinality() == 1) {
// single exit
out = BlockUtils.bitSetToOneBlock(mth, outs);
} else {
// several switch exits
// possible 'return', 'continue' or fallthrough in one of the cases
LoopInfo loop = mth.getLoopForBlock(block);
if (loop != null) {
outs.andNot(block.getPostDoms());
out = BlockUtils.bitSetToOneBlock(mth, outs);
if (out != null) {
insertContinueInSwitch(block, out, loop.getEnd());
if (out == loop.getStart()) {
// no other outs instead back edge to loop start
return null;
}
}
} else {
outs.clear(mth.getExitBlock().getId());
BlockNode imPostDom = block.getIPostDom();
if (outs.get(imPostDom.getId())) {
return imPostDom;
}
outs.andNot(block.getPostDoms());
out = BlockUtils.bitSetToOneBlock(mth, outs);
}
outs.and(ipdoms);
}
return BlockUtils.bitSetToOneBlock(mth, outs);
if (out != null && mth.isPreExitBlock(out)) {
// include 'return' or 'throw' in case blocks
out = mth.getExitBlock();
}
BlockNode imPostDom = block.getIPostDom();
if (out != imPostDom && !mth.isPreExitBlock(imPostDom)) {
// stop other paths at common exit
stack.addExit(imPostDom);
}
if (block.getCleanSuccessors().contains(imPostDom)) {
// add exit to stop on empty 'default' block
stack.addExit(imPostDom);
}
if (out == null) {
mth.addWarnComment("Failed to find 'out' block for switch in " + block + ". Please report as an issue.");
// fallback option; should work in most cases
out = block.getIPostDom();
}
if (out != null && processedBlocks.get(out.getId())) {
// 'out' block already processed, prevent endless loop
throw new JadxRuntimeException("Failed to find switch 'out' block (already processed)");
}
return out;
}
/**
@@ -999,18 +989,21 @@ public class RegionMaker {
return newBlocksMap;
}
private void insertContinueInSwitch(BlockNode block, BlockNode out, BlockNode end) {
int endId = end.getId();
for (BlockNode s : block.getCleanSuccessors()) {
if (s.getDomFrontier().get(endId) && s != out) {
private void insertContinueInSwitch(BlockNode switchBlock, BlockNode switchOut, BlockNode loopEnd) {
for (BlockNode caseBlock : switchBlock.getCleanSuccessors()) {
if (caseBlock.getDomFrontier().get(loopEnd.getId()) && caseBlock != switchOut) {
// search predecessor of loop end on path from this successor
List<BlockNode> list = BlockUtils.collectBlocksDominatedBy(mth, s, s);
for (BlockNode p : end.getPredecessors()) {
if (list.contains(p)) {
if (p.isSynthetic()) {
p.getInstructions().add(new InsnNode(InsnType.CONTINUE, 0));
Set<BlockNode> list = new HashSet<>(BlockUtils.collectBlocksDominatedBy(mth, caseBlock, caseBlock));
if (list.contains(switchOut) || switchOut.getPredecessors().stream().anyMatch(list::contains)) {
// 'continue' not needed
} else {
for (BlockNode p : loopEnd.getPredecessors()) {
if (list.contains(p)) {
if (p.isSynthetic()) {
p.getInstructions().add(new InsnNode(InsnType.CONTINUE, 0));
}
break;
}
break;
}
}
}
@@ -7,14 +7,13 @@ import java.util.BitSet;
import java.util.Collection;
import java.util.Collections;
import java.util.Deque;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.Set;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Predicate;
import org.jetbrains.annotations.Nullable;
@@ -37,6 +36,7 @@ import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.regions.conditions.IfCondition;
import jadx.core.dex.trycatch.CatchAttr;
import jadx.core.dex.trycatch.ExceptionHandler;
import jadx.core.utils.blocks.BlockSet;
import jadx.core.utils.exceptions.JadxRuntimeException;
public class BlockUtils {
@@ -483,33 +483,37 @@ public class BlockUtils {
public static List<BlockNode> collectAllSuccessors(MethodNode mth, BlockNode startBlock, boolean clean) {
List<BlockNode> list = new ArrayList<>(mth.getBasicBlocks().size());
dfsVisit(mth, startBlock, clean, list::add);
Function<BlockNode, List<BlockNode>> nextFunc = clean ? BlockNode::getCleanSuccessors : BlockNode::getSuccessors;
visitDFS(mth, startBlock, nextFunc, list::add);
return list;
}
public static void dfsVisit(MethodNode mth, Consumer<BlockNode> visitor) {
dfsVisit(mth, mth.getEnterBlock(), false, visitor);
public static void visitDFS(MethodNode mth, Consumer<BlockNode> visitor) {
visitDFS(mth, mth.getEnterBlock(), BlockNode::getSuccessors, visitor);
}
private static void dfsVisit(MethodNode mth, BlockNode startBlock, boolean clean, Consumer<BlockNode> visitor) {
BitSet visited = newBlocksBitSet(mth);
public static void visitReverseDFS(MethodNode mth, Consumer<BlockNode> visitor) {
visitDFS(mth, mth.getExitBlock(), BlockNode::getPredecessors, visitor);
}
private static void visitDFS(MethodNode mth, BlockNode startBlock,
Function<BlockNode, List<BlockNode>> nextFunc, Consumer<BlockNode> visitor) {
BlockSet visited = new BlockSet(mth);
Deque<BlockNode> queue = new ArrayDeque<>();
queue.addLast(startBlock);
visited.set(startBlock.getId());
visited.set(startBlock);
while (true) {
BlockNode current = queue.pollLast();
if (current == null) {
return;
}
visitor.accept(current);
List<BlockNode> successors = clean ? current.getCleanSuccessors() : current.getSuccessors();
int count = successors.size();
List<BlockNode> nextBlocks = nextFunc.apply(current);
int count = nextBlocks.size();
for (int i = count - 1; i >= 0; i--) { // to preserve order in queue
BlockNode next = successors.get(i);
int nextId = next.getId();
if (!visited.get(nextId)) {
BlockNode next = nextBlocks.get(i);
if (!visited.checkAndSet(next)) {
queue.addLast(next);
visited.set(nextId);
}
}
}
@@ -1156,104 +1160,6 @@ public class BlockUtils {
return false;
}
public static Map<BlockNode, BitSet> calcPostDominance(MethodNode mth) {
return calcPartialPostDominance(mth, mth.getBasicBlocks(), mth.getPreExitBlocks().get(0));
}
public static Map<BlockNode, BitSet> calcPartialPostDominance(MethodNode mth, Collection<BlockNode> blockNodes, BlockNode exitBlock) {
int blocksCount = mth.getBasicBlocks().size();
Map<BlockNode, BitSet> map = new HashMap<>(blocksCount);
BitSet initSet = new BitSet(blocksCount);
for (BlockNode block : blockNodes) {
initSet.set(block.getId());
}
for (BlockNode block : blockNodes) {
BitSet postDoms = new BitSet(blocksCount);
postDoms.or(initSet);
map.put(block, postDoms);
}
BitSet exitBitSet = map.get(exitBlock);
exitBitSet.clear();
exitBitSet.set(exitBlock.getId());
BitSet domSet = new BitSet(blocksCount);
boolean changed;
do {
changed = false;
for (BlockNode block : blockNodes) {
if (block == exitBlock) {
continue;
}
BitSet d = map.get(block);
if (!changed) {
domSet.clear();
domSet.or(d);
}
for (BlockNode scc : block.getSuccessors()) {
BitSet scPDoms = map.get(scc);
if (scPDoms != null) {
d.and(scPDoms);
}
}
d.set(block.getId());
if (!changed && !d.equals(domSet)) {
changed = true;
map.put(block, d);
}
}
} while (changed);
blockNodes.forEach(block -> {
BitSet postDoms = map.get(block);
postDoms.clear(block.getId());
if (postDoms.isEmpty()) {
map.put(block, EmptyBitSet.EMPTY);
}
});
return map;
}
@Nullable
public static BlockNode calcImmediatePostDominator(MethodNode mth, BlockNode block) {
BlockNode oneSuccessor = Utils.getOne(block.getSuccessors());
if (oneSuccessor != null) {
return oneSuccessor;
}
return calcImmediatePostDominator(mth, block, calcPostDominance(mth));
}
@Nullable
public static BlockNode calcPartialImmediatePostDominator(MethodNode mth, BlockNode block,
Collection<BlockNode> blockNodes, BlockNode exitBlock) {
BlockNode oneSuccessor = Utils.getOne(block.getSuccessors());
if (oneSuccessor != null) {
return oneSuccessor;
}
Map<BlockNode, BitSet> pDomsMap = calcPartialPostDominance(mth, blockNodes, exitBlock);
return calcImmediatePostDominator(mth, block, pDomsMap);
}
@Nullable
public static BlockNode calcImmediatePostDominator(MethodNode mth, BlockNode block, Map<BlockNode, BitSet> postDomsMap) {
BlockNode oneSuccessor = Utils.getOne(block.getSuccessors());
if (oneSuccessor != null) {
return oneSuccessor;
}
List<BlockNode> basicBlocks = mth.getBasicBlocks();
BitSet postDoms = postDomsMap.get(block);
BitSet bs = copyBlocksBitSet(mth, postDoms);
for (int i = bs.nextSetBit(0); i >= 0; i = bs.nextSetBit(i + 1)) {
BlockNode pdomBlock = basicBlocks.get(i);
BitSet pdoms = postDomsMap.get(pdomBlock);
if (pdoms != null) {
bs.andNot(pdoms);
}
}
return bitSetToOneBlock(mth, bs);
}
public static BlockNode getTopSplitterForHandler(BlockNode handlerBlock) {
BlockNode block = getBlockWithFlag(handlerBlock.getPredecessors(), AFlag.EXC_TOP_SPLITTER);
if (block == null) {
@@ -5,6 +5,7 @@ import org.junit.jupiter.api.Test;
import jadx.tests.api.IntegrationTest;
import static jadx.tests.api.utils.assertj.JadxAssertions.assertThat;
import static org.assertj.core.api.Assertions.catchThrowable;
public class TestSwitchWithThrow extends IntegrationTest {
@@ -19,19 +20,19 @@ public class TestSwitchWithThrow extends IntegrationTest {
default:
throw new IllegalStateException("Other");
}
} else {
System.out.println("0");
return -1;
}
System.out.println("0");
return -1;
}
public void check() {
assertThat(test(0)).isEqualTo(-1);
// TODO: implement 'invoke-custom' support
// assertThat(catchThrowable(() -> test(1)))
// .isInstanceOf(IllegalStateException.class).hasMessageContaining("1");
// assertThat(catchThrowable(() -> test(3)))
// .isInstanceOf(IllegalStateException.class).hasMessageContaining("Other");
assertThat(catchThrowable(() -> test(1)))
.isInstanceOf(IllegalStateException.class)
.hasMessageContaining("1");
assertThat(catchThrowable(() -> test(3)))
.isInstanceOf(IllegalStateException.class)
.hasMessageContaining("Other");
}
}