fix: improve BlockSet class, use it in more places

This commit is contained in:
Skylot
2025-03-19 22:01:47 +00:00
parent d1a3935c9e
commit dfa6a83f7c
15 changed files with 175 additions and 75 deletions
@@ -25,10 +25,9 @@ public final class BlockNode extends AttrNode implements IBlock, Comparable<Bloc
private final int cid;
/**
* ID linked to position in blocks list (easier to use BitSet)
* TODO: rename to avoid confusion
* Position in blocks list (easier to use BitSet)
*/
private int id;
private int pos;
/**
* Offset in methods bytecode
@@ -71,9 +70,9 @@ public final class BlockNode extends AttrNode implements IBlock, Comparable<Bloc
*/
private List<BlockNode> dominatesOn = new ArrayList<>(3);
public BlockNode(int cid, int id, int offset) {
public BlockNode(int cid, int pos, int offset) {
this.cid = cid;
this.id = id;
this.pos = pos;
this.startOffset = offset;
}
@@ -81,12 +80,20 @@ public final class BlockNode extends AttrNode implements IBlock, Comparable<Bloc
return cid;
}
void setId(int id) {
this.id = id;
void setPos(int id) {
this.pos = id;
}
/**
* Deprecated. Use {@link #getPos()}.
*/
@Deprecated
public int getId() {
return id;
return pos;
}
public int getPos() {
return pos;
}
public List<BlockNode> getPredecessors() {
@@ -105,6 +112,13 @@ public final class BlockNode extends AttrNode implements IBlock, Comparable<Bloc
cleanSuccessors = cleanSuccessors(this);
}
public static void updateBlockPositions(List<BlockNode> blocks) {
int count = blocks.size();
for (int i = 0; i < count; i++) {
blocks.get(i).setPos(i);
}
}
public void lock() {
try {
List<BlockNode> successorsList = successors;
@@ -161,7 +175,7 @@ public final class BlockNode extends AttrNode implements IBlock, Comparable<Bloc
* Check if 'block' dominated on this node
*/
public boolean isDominator(BlockNode block) {
return doms.get(block.getId());
return doms.get(block.getPos());
}
/**
@@ -236,7 +250,7 @@ public final class BlockNode extends AttrNode implements IBlock, Comparable<Bloc
@Override
public int hashCode() {
return startOffset;
return cid;
}
@Override
@@ -248,7 +262,7 @@ public final class BlockNode extends AttrNode implements IBlock, Comparable<Bloc
return false;
}
BlockNode other = (BlockNode) obj;
return cid == other.cid && startOffset == other.startOffset;
return cid == other.cid;
}
@Override
@@ -258,11 +272,11 @@ public final class BlockNode extends AttrNode implements IBlock, Comparable<Bloc
@Override
public String baseString() {
return Integer.toString(id);
return Integer.toString(cid);
}
@Override
public String toString() {
return "B:" + id + ':' + InsnUtils.formatOffset(startOffset);
return "B:" + cid + ':' + InsnUtils.formatOffset(startOffset);
}
}
@@ -366,14 +366,11 @@ public class MethodNode extends NotificationAttrNode implements IMethodDetails,
public void setBasicBlocks(List<BlockNode> blocks) {
this.blocks = blocks;
updateBlockIds(blocks);
updateBlockPositions();
}
public void updateBlockIds(List<BlockNode> blocks) {
int count = blocks.size();
for (int i = 0; i < count; i++) {
blocks.get(i).setId(i);
}
public void updateBlockPositions() {
BlockNode.updateBlockPositions(blocks);
}
public int getNextBlockCId() {
@@ -68,11 +68,11 @@ public abstract class ConditionRegion extends AbstractRegion implements IConditi
}
/**
* Prefer way for update condition info
* Preferred way to update condition info
*/
public void updateCondition(IfInfo info) {
this.condition = info.getCondition();
this.conditionBlocks = info.getMergedBlocks();
this.conditionBlocks = info.getMergedBlocks().toList();
}
public void updateCondition(IfCondition condition, List<BlockNode> conditionBlocks) {
@@ -8,11 +8,12 @@ import java.util.Set;
import jadx.core.dex.nodes.BlockNode;
import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.nodes.MethodNode;
import jadx.core.utils.blocks.BlockSet;
public final class IfInfo {
private final MethodNode mth;
private final IfCondition condition;
private final List<BlockNode> mergedBlocks;
private final BlockSet mergedBlocks;
private final BlockNode thenBlock;
private final BlockNode elseBlock;
private final Set<BlockNode> skipBlocks;
@@ -20,7 +21,7 @@ public final class IfInfo {
private BlockNode outBlock;
public IfInfo(MethodNode mth, IfCondition condition, BlockNode thenBlock, BlockNode elseBlock) {
this(mth, condition, thenBlock, elseBlock, new ArrayList<>(), new HashSet<>(), new ArrayList<>());
this(mth, condition, thenBlock, elseBlock, BlockSet.empty(mth), new HashSet<>(), new ArrayList<>());
}
public IfInfo(IfInfo info, BlockNode thenBlock, BlockNode elseBlock) {
@@ -29,7 +30,7 @@ public final class IfInfo {
}
private IfInfo(MethodNode mth, IfCondition condition, BlockNode thenBlock, BlockNode elseBlock,
List<BlockNode> mergedBlocks, Set<BlockNode> skipBlocks, List<InsnNode> forceInlineInsns) {
BlockSet mergedBlocks, Set<BlockNode> skipBlocks, List<InsnNode> forceInlineInsns) {
this.mth = mth;
this.condition = condition;
this.thenBlock = thenBlock;
@@ -56,7 +57,11 @@ public final class IfInfo {
@Deprecated
public BlockNode getFirstIfBlock() {
return mergedBlocks.get(0);
return mergedBlocks.getFirst();
}
public BlockSet getMergedBlocks() {
return mergedBlocks;
}
public MethodNode getMth() {
@@ -67,10 +72,6 @@ public final class IfInfo {
return condition;
}
public List<BlockNode> getMergedBlocks() {
return mergedBlocks;
}
public Set<BlockNode> getSkipBlocks() {
return skipBlocks;
}
@@ -89,7 +89,7 @@ public class DotGraphVisitor extends AbstractVisitor {
public void process(MethodNode mth) {
dot.startLine("digraph \"CFG for");
dot.add(escape(mth.getParentClass() + "." + mth.getMethodInfo().getShortId()));
dot.add(escape(mth.getMethodInfo().getFullId()));
dot.add("\" {");
BlockNode enterBlock = mth.getEnterBlock();
@@ -204,7 +204,7 @@ public class DotGraphVisitor extends AbstractVisitor {
dot.add("color=red,");
}
dot.add("label=\"{");
dot.add(String.valueOf(block.getId())).add("\\:\\ ");
dot.add(String.valueOf(block.getCId())).add("\\:\\ ");
dot.add(InsnUtils.formatOffset(block.getStartOffset()));
if (!attrs.isEmpty()) {
dot.add('|').add(attrs);
@@ -237,10 +237,10 @@ public class DotGraphVisitor extends AbstractVisitor {
if (PRINT_DOMINATORS) {
for (BlockNode c : block.getDominatesOn()) {
conn.startLine(block.getId() + " -> " + c.getId() + "[color=green];");
conn.startLine(block.getCId() + " -> " + c.getCId() + "[color=green];");
}
for (BlockNode dom : BlockUtils.bitSetToBlocks(mth, block.getDomFrontier())) {
conn.startLine("f_" + block.getId() + " -> f_" + dom.getId() + "[color=blue];");
conn.startLine("f_" + block.getCId() + " -> f_" + dom.getCId() + "[color=blue];");
}
}
}
@@ -280,7 +280,7 @@ public class DotGraphVisitor extends AbstractVisitor {
private String makeName(IContainer c) {
String name;
if (c instanceof BlockNode) {
name = "Node_" + ((BlockNode) c).getId();
name = "Node_" + ((BlockNode) c).getCId();
} else if (c instanceof IBlock) {
name = "Node_" + c.getClass().getSimpleName() + '_' + c.hashCode();
} else {
@@ -68,7 +68,7 @@ public class BlockExceptionHandler {
BlockProcessor.removeMarkedBlocks(mth);
BlockSet sorted = new BlockSet(mth);
BlockUtils.visitDFS(mth, sorted::set);
BlockUtils.visitDFS(mth, sorted::add);
removeUnusedExcHandlers(mth, tryBlocks, sorted);
return true;
}
@@ -623,7 +623,7 @@ public class BlockExceptionHandler {
for (ExceptionHandler eh : mth.getExceptionHandlers()) {
boolean notProcessed = true;
BlockNode handlerBlock = eh.getHandlerBlock();
if (handlerBlock == null || blocks.get(handlerBlock)) {
if (handlerBlock == null || blocks.contains(handlerBlock)) {
continue;
}
for (TryCatchBlockAttr tcb : tryBlocks) {
@@ -103,7 +103,7 @@ public class BlockProcessor extends AbstractVisitor {
* - post dominators (only if {@link AFlag#COMPUTE_POST_DOM} added to method)
* - loops and nested loop info
* </pre>
*
* <p>
* This method should be called after changing a block tree in custom passes added before
* {@link BlockFinisher}.
*/
@@ -679,7 +679,7 @@ public class BlockProcessor extends AbstractVisitor {
}
public static void removeMarkedBlocks(MethodNode mth) {
mth.getBasicBlocks().removeIf(block -> {
boolean removed = mth.getBasicBlocks().removeIf(block -> {
if (block.contains(AFlag.REMOVE)) {
if (!block.getPredecessors().isEmpty() || !block.getSuccessors().isEmpty()) {
LOG.warn("Block {} not deleted, method: {}", block, mth);
@@ -693,6 +693,9 @@ public class BlockProcessor extends AbstractVisitor {
}
return false;
});
if (removed) {
mth.updateBlockPositions();
}
}
private static void removeUnreachableBlocks(MethodNode mth) {
@@ -728,6 +731,7 @@ public class BlockProcessor extends AbstractVisitor {
toRemove.forEach(BlockSplitter::detachBlock);
mth.getBasicBlocks().removeAll(toRemove);
mth.updateBlockPositions();
}
private static void clearBlocksState(MethodNode mth) {
@@ -79,6 +79,7 @@ public class BlockSplitter extends AbstractVisitor {
addTempConnectionsForExcHandlers(mth, blocksMap);
setupExitConnections(mth);
mth.updateBlockPositions();
mth.unloadInsnArr();
}
@@ -20,14 +20,14 @@ public class PostDominatorTree {
int mthBlocksCount = mth.getBasicBlocks().size();
List<BlockNode> sorted = new ArrayList<>(mthBlocksCount);
BlockUtils.visitReverseDFS(mth, sorted::add);
// temporary set block ids to match reverse sorted order
// save old ids for later remapping
// temporary set block positions to match reverse sorted order
// save old positions for later remapping
int blocksCount = sorted.size();
int[] ids = new int[mthBlocksCount];
int[] posMapping = new int[mthBlocksCount];
for (int i = 0; i < blocksCount; i++) {
ids[i] = sorted.get(i).getId();
posMapping[i] = sorted.get(i).getPos();
}
mth.updateBlockIds(sorted);
BlockNode.updateBlockPositions(sorted);
BlockNode[] postDoms = DominatorTree.build(sorted, BlockNode::getSuccessors);
BlockNode firstBlock = sorted.get(0);
@@ -43,8 +43,8 @@ public class PostDominatorTree {
for (int i = 1; i < blocksCount; i++) {
BlockNode block = sorted.get(i);
BitSet bs = new BitSet(blocksCount);
block.getPostDoms().stream().forEach(n -> bs.set(ids[n]));
bs.clear(ids[i]);
block.getPostDoms().stream().forEach(n -> bs.set(posMapping[n]));
bs.clear(posMapping[i]);
block.setPostDoms(bs);
}
// check for missing blocks in 'sorted' list
@@ -65,8 +65,8 @@ public class PostDominatorTree {
// show error as a warning because this info not always used
mth.addWarnComment("Failed to build post-dominance tree", e);
} finally {
// revert block ids change
mth.updateBlockIds(mth.getBasicBlocks());
// revert block positions change
mth.updateBlockPositions();
}
}
}
@@ -29,6 +29,7 @@ import jadx.core.dex.regions.conditions.IfInfo;
import jadx.core.dex.regions.conditions.IfRegion;
import jadx.core.dex.regions.loops.LoopRegion;
import jadx.core.utils.BlockUtils;
import jadx.core.utils.blocks.BlockSet;
import jadx.core.utils.exceptions.JadxRuntimeException;
import static jadx.core.utils.BlockUtils.isEqualPaths;
@@ -228,7 +229,7 @@ final class IfRegionMaker {
private static boolean allPathsFromIf(BlockNode block, IfInfo info) {
List<BlockNode> preds = block.getPredecessors();
List<BlockNode> ifBlocks = info.getMergedBlocks();
BlockSet ifBlocks = info.getMergedBlocks();
for (BlockNode pred : preds) {
if (pred.contains(AFlag.LOOP_END)) {
// ignore loop back edge
@@ -77,7 +77,7 @@ final class LoopRegionMaker {
loopRegion.updateCondition(condInfo);
// prevent if's merge with loop condition
condInfo.getMergedBlocks().forEach(b -> b.add(AFlag.ADDED_TO_REGION));
exitBlocks.removeAll(condInfo.getMergedBlocks());
exitBlocks.removeAll(condInfo.getMergedBlocks().toList());
if (!exitBlocks.isEmpty()) {
BlockNode loopExit = condInfo.getElseBlock();
@@ -1,7 +1,6 @@
package jadx.core.dex.visitors.regions.maker;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.List;
import java.util.Objects;
@@ -19,6 +18,7 @@ import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.regions.Region;
import jadx.core.utils.BlockUtils;
import jadx.core.utils.blocks.BlockSet;
import jadx.core.utils.exceptions.JadxOverflowException;
import static jadx.core.utils.BlockUtils.getNextBlock;
@@ -30,7 +30,7 @@ public class RegionMaker {
private final IfRegionMaker ifMaker;
private final LoopRegionMaker loopMaker;
private final BitSet processedBlocks;
private final BlockSet processedBlocks;
private final int regionsLimit;
private int regionsCount;
@@ -40,9 +40,8 @@ public class RegionMaker {
this.stack = new RegionStack(mth);
this.ifMaker = new IfRegionMaker(mth, this);
this.loopMaker = new LoopRegionMaker(mth, this, ifMaker);
int blocksCount = mth.getBasicBlocks().size();
this.processedBlocks = new BitSet(blocksCount);
this.regionsLimit = blocksCount * 100;
this.processedBlocks = BlockSet.empty(mth);
this.regionsLimit = mth.getBasicBlocks().size() * 100;
}
public Region makeMthRegion() {
@@ -57,12 +56,10 @@ public class RegionMaker {
return region;
}
int startBlockId = startBlock.getId();
if (processedBlocks.get(startBlockId)) {
if (processedBlocks.addChecked(startBlock)) {
mth.addWarn("Removed duplicated region for block: " + startBlock + ' ' + startBlock.getAttributesString());
return region;
}
processedBlocks.set(startBlockId);
BlockNode next = startBlock;
while (next != null) {
@@ -159,10 +156,10 @@ public class RegionMaker {
}
boolean isProcessed(BlockNode block) {
return processedBlocks.get(block.getId());
return processedBlocks.contains(block);
}
void clearBlockProcessedState(BlockNode block) {
processedBlocks.clear(block.getId());
processedBlocks.remove(block);
}
}
@@ -443,7 +443,7 @@ public class BlockUtils {
*/
public static @Nullable BlockNode getPrevBlockOnPath(MethodNode mth, BlockNode block, BlockNode pathStart) {
BlockSet preds = BlockSet.from(mth, block.getPredecessors());
if (preds.get(pathStart)) {
if (preds.contains(pathStart)) {
return pathStart;
}
DFSIteration dfs = new DFSIteration(mth, pathStart, BlockNode::getCleanSuccessors);
@@ -452,7 +452,7 @@ public class BlockUtils {
if (next == null) {
return null;
}
if (preds.get(next)) {
if (preds.contains(next)) {
return next;
}
}
@@ -4,20 +4,32 @@ import java.util.ArrayList;
import java.util.BitSet;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.Spliterator;
import java.util.Spliterators;
import java.util.function.Consumer;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import jadx.core.dex.nodes.BlockNode;
import jadx.core.dex.nodes.MethodNode;
import jadx.core.utils.EmptyBitSet;
public class BlockSet {
/**
* BlockNode set implementation based on BitSet.
*/
public class BlockSet implements Iterable<BlockNode> {
public static BlockSet empty(MethodNode mth) {
return new BlockSet(mth);
}
public static BlockSet from(MethodNode mth, Collection<BlockNode> blocks) {
BlockSet newBS = new BlockSet(mth);
newBS.set(blocks);
newBS.addAll(blocks);
return newBS;
}
@@ -29,28 +41,49 @@ public class BlockSet {
this.bs = new BitSet(mth.getBasicBlocks().size());
}
public boolean get(BlockNode block) {
return bs.get(block.getId());
public boolean contains(BlockNode block) {
return bs.get(block.getPos());
}
public void set(BlockNode block) {
bs.set(block.getId());
public void add(BlockNode block) {
bs.set(block.getPos());
}
public void set(Collection<BlockNode> blocks) {
blocks.forEach(this::set);
public void addAll(Collection<BlockNode> blocks) {
blocks.forEach(this::add);
}
public boolean checkAndSet(BlockNode block) {
int id = block.getId();
public void addAll(BlockSet otherBlockSet) {
bs.or(otherBlockSet.bs);
}
public void remove(BlockNode block) {
bs.clear(block.getPos());
}
public void remove(Collection<BlockNode> blocks) {
blocks.forEach(this::remove);
}
public boolean addChecked(BlockNode block) {
int id = block.getPos();
boolean state = bs.get(id);
bs.set(id);
return state;
}
public boolean containsAll(List<BlockNode> blocks) {
for (BlockNode block : blocks) {
if (!contains(block)) {
return false;
}
}
return true;
}
public boolean intersects(List<BlockNode> blocks) {
for (BlockNode block : blocks) {
if (get(block)) {
if (contains(block)) {
return true;
}
}
@@ -67,13 +100,17 @@ public class BlockSet {
}
public boolean isEmpty() {
return bs.cardinality() == 0;
return bs.isEmpty();
}
public int size() {
return bs.cardinality();
}
public void remove() {
bs.clear();
}
public @Nullable BlockNode getOne() {
if (bs.cardinality() == 1) {
return mth.getBasicBlocks().get(bs.nextSetBit(0));
@@ -81,6 +118,11 @@ public class BlockSet {
return null;
}
public BlockNode getFirst() {
return mth.getBasicBlocks().get(bs.nextSetBit(0));
}
@Override
public void forEach(Consumer<? super BlockNode> consumer) {
if (bs.isEmpty()) {
return;
@@ -91,6 +133,18 @@ public class BlockSet {
}
}
@Override
public @NotNull Iterator<BlockNode> iterator() {
return new BlockSetIterator(bs, size(), mth.getBasicBlocks());
}
@Override
public Spliterator<BlockNode> spliterator() {
int size = size();
BlockSetIterator iterator = new BlockSetIterator(bs, size, mth.getBasicBlocks());
return Spliterators.spliterator(iterator, size, Spliterator.ORDERED | Spliterator.DISTINCT);
}
public List<BlockNode> toList() {
if (bs == null || bs == EmptyBitSet.EMPTY) {
return Collections.emptyList();
@@ -111,4 +165,35 @@ public class BlockSet {
public String toString() {
return toList().toString();
}
private static final class BlockSetIterator implements Iterator<BlockNode> {
private final BitSet bs;
private final int size;
private final List<BlockNode> blocks;
private int cursor;
private int start;
public BlockSetIterator(BitSet bs, int size, List<BlockNode> blocks) {
this.bs = bs;
this.size = size;
this.blocks = blocks;
}
@Override
public boolean hasNext() {
return cursor != size;
}
@Override
public BlockNode next() {
int pos = bs.nextSetBit(start);
if (pos == -1) {
throw new NoSuchElementException();
}
start = pos + 1;
cursor++;
return blocks.get(pos);
}
}
}
@@ -20,7 +20,7 @@ public class DFSIteration {
queue = new ArrayDeque<>();
visited = new BlockSet(mth);
queue.addLast(startBlock);
visited.set(startBlock);
visited.add(startBlock);
}
public @Nullable BlockNode next() {
@@ -32,7 +32,7 @@ public class DFSIteration {
int count = nextBlocks.size();
for (int i = count - 1; i >= 0; i--) { // to preserve order in queue
BlockNode next = nextBlocks.get(i);
if (!visited.checkAndSet(next)) {
if (!visited.addChecked(next)) {
queue.addLast(next);
}
}