refactor: split region maker

This commit is contained in:
Skylot
2024-09-22 20:10:37 +01:00
parent 8f27de4d0e
commit 9c30aeacdb
11 changed files with 1548 additions and 1365 deletions
@@ -0,0 +1,131 @@
package jadx.core.dex.visitors.regions;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import jadx.core.dex.attributes.AFlag;
import jadx.core.dex.attributes.AType;
import jadx.core.dex.attributes.nodes.EdgeInsnAttr;
import jadx.core.dex.instructions.InsnType;
import jadx.core.dex.nodes.BlockNode;
import jadx.core.dex.nodes.IBlock;
import jadx.core.dex.nodes.IContainer;
import jadx.core.dex.nodes.IRegion;
import jadx.core.dex.nodes.InsnContainer;
import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.regions.Region;
import jadx.core.dex.regions.SwitchRegion;
import jadx.core.dex.regions.loops.LoopRegion;
import jadx.core.utils.RegionUtils;
final class PostProcessRegions extends AbstractRegionVisitor {
private static final Logger LOG = LoggerFactory.getLogger(PostProcessRegions.class);
private static final IRegionVisitor INSTANCE = new PostProcessRegions();
static void process(MethodNode mth) {
DepthRegionTraversal.traverse(mth, INSTANCE);
}
@Override
public void leaveRegion(MethodNode mth, IRegion region) {
if (region instanceof LoopRegion) {
// merge conditions in loops
LoopRegion loop = (LoopRegion) region;
loop.mergePreCondition();
} else if (region instanceof SwitchRegion) {
// insert 'break' in switch cases (run after try/catch insertion)
processSwitch(mth, (SwitchRegion) region);
} else if (region instanceof Region) {
insertEdgeInsn((Region) region);
}
}
/**
* Insert insn block from edge insn attribute.
*/
private static void insertEdgeInsn(Region region) {
List<IContainer> subBlocks = region.getSubBlocks();
if (subBlocks.isEmpty()) {
return;
}
IContainer last = subBlocks.get(subBlocks.size() - 1);
List<EdgeInsnAttr> edgeInsnAttrs = last.getAll(AType.EDGE_INSN);
if (edgeInsnAttrs.isEmpty()) {
return;
}
EdgeInsnAttr insnAttr = edgeInsnAttrs.get(0);
if (!insnAttr.getStart().equals(last)) {
return;
}
if (last instanceof BlockNode) {
BlockNode block = (BlockNode) last;
if (block.getInstructions().isEmpty()) {
block.getInstructions().add(insnAttr.getInsn());
return;
}
}
List<InsnNode> insns = Collections.singletonList(insnAttr.getInsn());
region.add(new InsnContainer(insns));
}
private static void processSwitch(MethodNode mth, SwitchRegion sw) {
for (IContainer c : sw.getBranches()) {
if (c instanceof Region) {
Set<IBlock> blocks = new HashSet<>();
RegionUtils.getAllRegionBlocks(c, blocks);
if (blocks.isEmpty()) {
addBreakToContainer((Region) c);
} else {
for (IBlock block : blocks) {
if (block instanceof BlockNode) {
addBreakForBlock(mth, c, blocks, (BlockNode) block);
}
}
}
}
}
}
private static void addBreakToContainer(Region c) {
if (RegionUtils.hasExitEdge(c)) {
return;
}
List<InsnNode> insns = new ArrayList<>(1);
insns.add(new InsnNode(InsnType.BREAK, 0));
c.add(new InsnContainer(insns));
}
private static void addBreakForBlock(MethodNode mth, IContainer c, Set<IBlock> blocks, BlockNode bn) {
for (BlockNode s : bn.getCleanSuccessors()) {
if (!blocks.contains(s)
&& !bn.contains(AFlag.ADDED_TO_REGION)
&& !s.contains(AFlag.FALL_THROUGH)) {
addBreak(mth, c, bn);
return;
}
}
}
private static void addBreak(MethodNode mth, IContainer c, BlockNode bn) {
IContainer blockContainer = RegionUtils.getBlockContainer(c, bn);
if (blockContainer instanceof Region) {
addBreakToContainer((Region) blockContainer);
} else if (c instanceof Region) {
addBreakToContainer((Region) c);
} else {
LOG.warn("Can't insert break, container: {}, block: {}, mth: {}", blockContainer, bn, mth);
}
}
private PostProcessRegions() {
// singleton
}
}
File diff suppressed because it is too large Load Diff
@@ -1,43 +1,20 @@
package jadx.core.dex.visitors.regions;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import jadx.core.dex.attributes.AFlag;
import jadx.core.dex.attributes.AType;
import jadx.core.dex.attributes.nodes.EdgeInsnAttr;
import jadx.core.dex.instructions.InsnType;
import jadx.core.dex.nodes.BlockNode;
import jadx.core.dex.nodes.IBlock;
import jadx.core.dex.nodes.IContainer;
import jadx.core.dex.nodes.IRegion;
import jadx.core.dex.nodes.InsnContainer;
import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.regions.Region;
import jadx.core.dex.regions.SwitchRegion;
import jadx.core.dex.regions.SynchronizedRegion;
import jadx.core.dex.regions.loops.LoopRegion;
import jadx.core.dex.visitors.AbstractVisitor;
import jadx.core.dex.visitors.JadxVisitor;
import jadx.core.dex.visitors.regions.maker.ExcHandlersRegionMaker;
import jadx.core.dex.visitors.regions.maker.RegionMaker;
import jadx.core.dex.visitors.regions.maker.SynchronizedRegionMaker;
import jadx.core.dex.visitors.shrink.CodeShrinkVisitor;
import jadx.core.utils.InsnRemover;
import jadx.core.utils.RegionUtils;
import jadx.core.utils.Utils;
import jadx.core.utils.exceptions.JadxException;
/**
* Pack blocks into regions for code generation
*/
@JadxVisitor(
name = "RegionMakerVisitor",
desc = "Pack blocks into regions for code generation"
)
public class RegionMakerVisitor extends AbstractVisitor {
private static final Logger LOG = LoggerFactory.getLogger(RegionMakerVisitor.class);
private static final IRegionVisitor POST_REGION_VISITOR = new PostRegionVisitor();
@Override
public void visit(MethodNode mth) throws JadxException {
@@ -45,33 +22,16 @@ public class RegionMakerVisitor extends AbstractVisitor {
return;
}
RegionMaker rm = new RegionMaker(mth);
RegionStack state = new RegionStack(mth);
// fill region structure
BlockNode startBlock = Utils.first(mth.getEnterBlock().getCleanSuccessors());
mth.setRegion(rm.makeRegion(startBlock, state));
mth.setRegion(rm.makeMthRegion());
if (!mth.isNoExceptionHandlers()) {
IRegion expOutBlock = rm.processTryCatchBlocks(mth);
if (expOutBlock != null) {
mth.getRegion().add(expOutBlock);
}
new ExcHandlersRegionMaker(mth, rm).process();
}
postProcessRegions(mth);
}
private static void postProcessRegions(MethodNode mth) {
processForceInlineInsns(mth);
// make try-catch regions
ProcessTryCatchRegions.process(mth);
DepthRegionTraversal.traverse(mth, POST_REGION_VISITOR);
PostProcessRegions.process(mth);
CleanRegions.process(mth);
if (mth.getAccessFlags().isSynchronized()) {
removeSynchronized(mth);
SynchronizedRegionMaker.removeSynchronized(mth);
}
}
@@ -84,120 +44,8 @@ public class RegionMakerVisitor extends AbstractVisitor {
}
}
private static final class PostRegionVisitor extends AbstractRegionVisitor {
@Override
public void leaveRegion(MethodNode mth, IRegion region) {
if (region instanceof LoopRegion) {
// merge conditions in loops
LoopRegion loop = (LoopRegion) region;
loop.mergePreCondition();
} else if (region instanceof SwitchRegion) {
// insert 'break' in switch cases (run after try/catch insertion)
processSwitch(mth, (SwitchRegion) region);
} else if (region instanceof Region) {
insertEdgeInsn((Region) region);
}
}
/**
* Insert insn block from edge insn attribute.
*/
private static void insertEdgeInsn(Region region) {
List<IContainer> subBlocks = region.getSubBlocks();
if (subBlocks.isEmpty()) {
return;
}
IContainer last = subBlocks.get(subBlocks.size() - 1);
List<EdgeInsnAttr> edgeInsnAttrs = last.getAll(AType.EDGE_INSN);
if (edgeInsnAttrs.isEmpty()) {
return;
}
EdgeInsnAttr insnAttr = edgeInsnAttrs.get(0);
if (!insnAttr.getStart().equals(last)) {
return;
}
if (last instanceof BlockNode) {
BlockNode block = (BlockNode) last;
if (block.getInstructions().isEmpty()) {
block.getInstructions().add(insnAttr.getInsn());
return;
}
}
List<InsnNode> insns = Collections.singletonList(insnAttr.getInsn());
region.add(new InsnContainer(insns));
}
private static void processSwitch(MethodNode mth, SwitchRegion sw) {
for (IContainer c : sw.getBranches()) {
if (c instanceof Region) {
Set<IBlock> blocks = new HashSet<>();
RegionUtils.getAllRegionBlocks(c, blocks);
if (blocks.isEmpty()) {
addBreakToContainer((Region) c);
} else {
for (IBlock block : blocks) {
if (block instanceof BlockNode) {
addBreakForBlock(mth, c, blocks, (BlockNode) block);
}
}
}
}
}
}
private static void addBreakToContainer(Region c) {
if (RegionUtils.hasExitEdge(c)) {
return;
}
List<InsnNode> insns = new ArrayList<>(1);
insns.add(new InsnNode(InsnType.BREAK, 0));
c.add(new InsnContainer(insns));
}
private static void addBreakForBlock(MethodNode mth, IContainer c, Set<IBlock> blocks, BlockNode bn) {
for (BlockNode s : bn.getCleanSuccessors()) {
if (!blocks.contains(s)
&& !bn.contains(AFlag.ADDED_TO_REGION)
&& !s.contains(AFlag.FALL_THROUGH)) {
addBreak(mth, c, bn);
return;
}
}
}
private static void addBreak(MethodNode mth, IContainer c, BlockNode bn) {
IContainer blockContainer = RegionUtils.getBlockContainer(c, bn);
if (blockContainer instanceof Region) {
addBreakToContainer((Region) blockContainer);
} else if (c instanceof Region) {
addBreakToContainer((Region) c);
} else {
LOG.warn("Can't insert break, container: {}, block: {}, mth: {}", blockContainer, bn, mth);
}
}
}
private static void removeSynchronized(MethodNode mth) {
Region startRegion = mth.getRegion();
List<IContainer> subBlocks = startRegion.getSubBlocks();
if (!subBlocks.isEmpty() && subBlocks.get(0) instanceof SynchronizedRegion) {
SynchronizedRegion synchRegion = (SynchronizedRegion) subBlocks.get(0);
InsnNode synchInsn = synchRegion.getEnterInsn();
if (!synchInsn.getArg(0).isThis()) {
LOG.warn("In synchronized method {}, top region not synchronized by 'this' {}", mth, synchInsn);
return;
}
// replace synchronized block with inner region
startRegion.getSubBlocks().set(0, synchRegion.getRegion());
// remove 'monitor-enter' instruction
InsnRemover.remove(mth, synchInsn);
// remove 'monitor-exit' instruction
for (InsnNode exit : synchRegion.getExitInsns()) {
InsnRemover.remove(mth, exit);
}
// run region cleaner again
CleanRegions.process(mth);
// assume that CodeShrinker will be run after this
}
@Override
public String getName() {
return "RegionMakerVisitor";
}
}
@@ -0,0 +1,153 @@
package jadx.core.dex.visitors.regions.maker;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import org.jetbrains.annotations.Nullable;
import jadx.core.dex.attributes.AFlag;
import jadx.core.dex.attributes.AType;
import jadx.core.dex.nodes.BlockNode;
import jadx.core.dex.nodes.IBlock;
import jadx.core.dex.nodes.IContainer;
import jadx.core.dex.nodes.IRegion;
import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.regions.Region;
import jadx.core.dex.trycatch.ExcHandlerAttr;
import jadx.core.dex.trycatch.ExceptionHandler;
import jadx.core.dex.trycatch.TryCatchBlockAttr;
import jadx.core.utils.BlockUtils;
import jadx.core.utils.RegionUtils;
public class ExcHandlersRegionMaker {
private final MethodNode mth;
private final RegionMaker regionMaker;
public ExcHandlersRegionMaker(MethodNode mth, RegionMaker regionMaker) {
this.mth = mth;
this.regionMaker = regionMaker;
}
public void process() {
if (mth.isNoExceptionHandlers()) {
return;
}
IRegion excOutBlock = collectHandlerRegions();
if (excOutBlock != null) {
mth.getRegion().add(excOutBlock);
}
}
private @Nullable IRegion collectHandlerRegions() {
List<TryCatchBlockAttr> tcs = mth.getAll(AType.TRY_BLOCKS_LIST);
for (TryCatchBlockAttr tc : tcs) {
List<BlockNode> blocks = new ArrayList<>(tc.getHandlersCount());
Set<BlockNode> splitters = new HashSet<>();
for (ExceptionHandler handler : tc.getHandlers()) {
BlockNode handlerBlock = handler.getHandlerBlock();
if (handlerBlock != null) {
blocks.add(handlerBlock);
splitters.add(BlockUtils.getTopSplitterForHandler(handlerBlock));
} else {
mth.addDebugComment("No exception handler block: " + handler);
}
}
Set<BlockNode> exits = new HashSet<>();
for (BlockNode splitter : splitters) {
for (BlockNode handler : blocks) {
if (handler.contains(AFlag.REMOVE)) {
continue;
}
List<BlockNode> s = splitter.getSuccessors();
if (s.isEmpty()) {
mth.addDebugComment("No successors for splitter: " + splitter);
continue;
}
BlockNode ss = s.get(0);
BlockNode cross = BlockUtils.getPathCross(mth, ss, handler);
if (cross != null && cross != ss && cross != handler) {
exits.add(cross);
}
}
}
for (ExceptionHandler handler : tc.getHandlers()) {
processExcHandler(handler, exits);
}
}
return processHandlersOutBlocks(tcs);
}
/**
* Search handlers successor blocks aren't included in any region.
*/
private @Nullable IRegion processHandlersOutBlocks(List<TryCatchBlockAttr> tcs) {
Set<IBlock> allRegionBlocks = new HashSet<>();
RegionUtils.getAllRegionBlocks(mth.getRegion(), allRegionBlocks);
Set<IBlock> successorBlocks = new HashSet<>();
for (TryCatchBlockAttr tc : tcs) {
for (ExceptionHandler handler : tc.getHandlers()) {
IContainer region = handler.getHandlerRegion();
if (region != null) {
IBlock lastBlock = RegionUtils.getLastBlock(region);
if (lastBlock instanceof BlockNode) {
successorBlocks.addAll(((BlockNode) lastBlock).getSuccessors());
}
RegionUtils.getAllRegionBlocks(region, allRegionBlocks);
}
}
}
successorBlocks.removeAll(allRegionBlocks);
if (successorBlocks.isEmpty()) {
return null;
}
RegionStack stack = regionMaker.getStack();
Region excOutRegion = new Region(mth.getRegion());
for (IBlock block : successorBlocks) {
if (block instanceof BlockNode) {
stack.clear();
stack.push(excOutRegion);
excOutRegion.add(regionMaker.makeRegion((BlockNode) block));
}
}
return excOutRegion;
}
private void processExcHandler(ExceptionHandler handler, Set<BlockNode> exits) {
BlockNode start = handler.getHandlerBlock();
if (start == null) {
return;
}
RegionStack stack = regionMaker.getStack().clear();
BlockNode dom;
if (handler.isFinally()) {
dom = BlockUtils.getTopSplitterForHandler(start);
} else {
dom = start;
stack.addExits(exits);
}
if (dom.contains(AFlag.REMOVE)) {
return;
}
BitSet domFrontier = dom.getDomFrontier();
List<BlockNode> handlerExits = BlockUtils.bitSetToBlocks(mth, domFrontier);
boolean inLoop = mth.getLoopForBlock(start) != null;
for (BlockNode exit : handlerExits) {
if ((!inLoop || BlockUtils.isPathExists(start, exit))
&& RegionUtils.isRegionContainsBlock(mth.getRegion(), exit)) {
stack.addExit(exit);
}
}
handler.setHandlerRegion(regionMaker.makeRegion(start));
ExcHandlerAttr excHandlerAttr = start.get(AType.EXC_HANDLER);
if (excHandlerAttr == null) {
mth.addWarn("Missing exception handler attribute for start block: " + start);
} else {
handler.getHandlerRegion().addAttr(excHandlerAttr);
}
}
}
@@ -1,10 +1,11 @@
package jadx.core.dex.visitors.regions;
package jadx.core.dex.visitors.regions.maker;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -12,28 +13,133 @@ import org.slf4j.LoggerFactory;
import jadx.core.Consts;
import jadx.core.dex.attributes.AFlag;
import jadx.core.dex.attributes.AType;
import jadx.core.dex.attributes.nodes.EdgeInsnAttr;
import jadx.core.dex.attributes.nodes.LoopInfo;
import jadx.core.dex.instructions.IfNode;
import jadx.core.dex.instructions.InsnType;
import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.instructions.args.RegisterArg;
import jadx.core.dex.nodes.BlockNode;
import jadx.core.dex.nodes.IRegion;
import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.regions.Region;
import jadx.core.dex.regions.conditions.IfCondition;
import jadx.core.dex.regions.conditions.IfCondition.Mode;
import jadx.core.dex.regions.conditions.IfInfo;
import jadx.core.dex.regions.conditions.IfRegion;
import jadx.core.dex.regions.loops.LoopRegion;
import jadx.core.utils.BlockUtils;
import jadx.core.utils.exceptions.JadxRuntimeException;
import static jadx.core.dex.visitors.regions.RegionMaker.isEqualPaths;
import static jadx.core.dex.visitors.regions.RegionMaker.isEqualReturnBlocks;
import static jadx.core.utils.BlockUtils.isEqualPaths;
import static jadx.core.utils.BlockUtils.isEqualReturnBlocks;
import static jadx.core.utils.BlockUtils.isPathExists;
public class IfMakerHelper {
private static final Logger LOG = LoggerFactory.getLogger(IfMakerHelper.class);
final class IfRegionMaker {
private static final Logger LOG = LoggerFactory.getLogger(IfRegionMaker.class);
private final MethodNode mth;
private final RegionMaker regionMaker;
private IfMakerHelper() {
IfRegionMaker(MethodNode mth, RegionMaker regionMaker) {
this.mth = mth;
this.regionMaker = regionMaker;
}
BlockNode process(IRegion currentRegion, BlockNode block, IfNode ifnode, RegionStack stack) {
if (block.contains(AFlag.ADDED_TO_REGION)) {
// block already included in other 'if' region
return ifnode.getThenBlock();
}
IfInfo currentIf = makeIfInfo(mth, block);
if (currentIf == null) {
return null;
}
IfInfo mergedIf = mergeNestedIfNodes(currentIf);
if (mergedIf != null) {
currentIf = mergedIf;
} else {
// invert simple condition (compiler often do it)
currentIf = IfInfo.invert(currentIf);
}
IfInfo modifiedIf = restructureIf(mth, block, currentIf);
if (modifiedIf != null) {
currentIf = modifiedIf;
} else {
if (currentIf.getMergedBlocks().size() <= 1) {
return null;
}
currentIf = makeIfInfo(mth, block);
currentIf = restructureIf(mth, block, currentIf);
if (currentIf == null) {
// all attempts failed
return null;
}
}
confirmMerge(currentIf);
IfRegion ifRegion = new IfRegion(currentRegion);
ifRegion.updateCondition(currentIf);
currentRegion.getSubBlocks().add(ifRegion);
BlockNode outBlock = currentIf.getOutBlock();
stack.push(ifRegion);
stack.addExit(outBlock);
BlockNode thenBlock = currentIf.getThenBlock();
if (thenBlock == null) {
// empty then block, not normal, but maybe correct
ifRegion.setThenRegion(new Region(ifRegion));
} else {
ifRegion.setThenRegion(regionMaker.makeRegion(thenBlock));
}
BlockNode elseBlock = currentIf.getElseBlock();
if (elseBlock == null || stack.containsExit(elseBlock)) {
ifRegion.setElseRegion(null);
} else {
ifRegion.setElseRegion(regionMaker.makeRegion(elseBlock));
}
// insert edge insns in new 'else' branch
// TODO: make more common algorithm
if (ifRegion.getElseRegion() == null && outBlock != null) {
List<EdgeInsnAttr> edgeInsnAttrs = outBlock.getAll(AType.EDGE_INSN);
if (!edgeInsnAttrs.isEmpty()) {
Region elseRegion = new Region(ifRegion);
for (EdgeInsnAttr edgeInsnAttr : edgeInsnAttrs) {
if (edgeInsnAttr.getEnd().equals(outBlock)) {
addEdgeInsn(currentIf, elseRegion, edgeInsnAttr);
}
}
ifRegion.setElseRegion(elseRegion);
}
}
stack.pop();
return outBlock;
}
@NotNull
IfInfo buildIfInfo(LoopRegion loopRegion) {
IfInfo condInfo = makeIfInfo(mth, loopRegion.getHeader());
condInfo = searchNestedIf(condInfo);
confirmMerge(condInfo);
return condInfo;
}
private void addEdgeInsn(IfInfo ifInfo, Region region, EdgeInsnAttr edgeInsnAttr) {
BlockNode start = edgeInsnAttr.getStart();
boolean fromThisIf = false;
for (BlockNode ifBlock : ifInfo.getMergedBlocks()) {
if (ifBlock.getSuccessors().contains(start)) {
fromThisIf = true;
break;
}
}
if (!fromThisIf) {
return;
}
region.add(start);
}
@Nullable
@@ -262,7 +368,7 @@ public class IfMakerHelper {
return from.getCleanSuccessors().size() == 1 && from.getCleanSuccessors().contains(to);
}
private static IfInfo mergeIfInfo(IfInfo first, IfInfo second, boolean followThenBranch) {
static IfInfo mergeIfInfo(IfInfo first, IfInfo second, boolean followThenBranch) {
MethodNode mth = first.getMth();
Set<BlockNode> skipBlocks = first.getSkipBlocks();
BlockNode thenBlock;
@@ -274,7 +380,7 @@ public class IfMakerHelper {
thenBlock = getBranchBlock(first.getThenBlock(), second.getThenBlock(), skipBlocks, mth);
elseBlock = second.getElseBlock();
}
Mode mergeOperation = followThenBranch ? Mode.AND : Mode.OR;
IfCondition.Mode mergeOperation = followThenBranch ? IfCondition.Mode.AND : IfCondition.Mode.OR;
IfCondition condition = IfCondition.merge(mergeOperation, first.getCondition(), second.getCondition());
IfInfo result = new IfInfo(mth, condition, thenBlock, elseBlock);
result.merge(first, second);
@@ -0,0 +1,464 @@
package jadx.core.dex.visitors.regions.maker;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import jadx.core.dex.attributes.AFlag;
import jadx.core.dex.attributes.AType;
import jadx.core.dex.attributes.nodes.EdgeInsnAttr;
import jadx.core.dex.attributes.nodes.LoopInfo;
import jadx.core.dex.attributes.nodes.LoopLabelAttr;
import jadx.core.dex.instructions.InsnType;
import jadx.core.dex.nodes.BlockNode;
import jadx.core.dex.nodes.Edge;
import jadx.core.dex.nodes.IRegion;
import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.regions.Region;
import jadx.core.dex.regions.conditions.IfInfo;
import jadx.core.dex.regions.loops.LoopRegion;
import jadx.core.dex.trycatch.ExceptionHandler;
import jadx.core.utils.BlockUtils;
import jadx.core.utils.ListUtils;
import jadx.core.utils.RegionUtils;
import jadx.core.utils.exceptions.JadxRuntimeException;
import static jadx.core.utils.BlockUtils.getNextBlock;
import static jadx.core.utils.BlockUtils.isPathExists;
final class LoopRegionMaker {
private final MethodNode mth;
private final RegionMaker regionMaker;
private final IfRegionMaker ifMaker;
LoopRegionMaker(MethodNode mth, RegionMaker regionMaker, IfRegionMaker ifMaker) {
this.mth = mth;
this.regionMaker = regionMaker;
this.ifMaker = ifMaker;
}
BlockNode process(IRegion curRegion, LoopInfo loop, RegionStack stack) {
BlockNode loopStart = loop.getStart();
Set<BlockNode> exitBlocksSet = loop.getExitNodes();
// set exit blocks scan order priority
// this can help if loop has several exits (after using 'break' or 'return' in loop)
List<BlockNode> exitBlocks = new ArrayList<>(exitBlocksSet.size());
BlockNode nextStart = getNextBlock(loopStart);
if (nextStart != null && exitBlocksSet.remove(nextStart)) {
exitBlocks.add(nextStart);
}
if (exitBlocksSet.remove(loopStart)) {
exitBlocks.add(loopStart);
}
if (exitBlocksSet.remove(loop.getEnd())) {
exitBlocks.add(loop.getEnd());
}
exitBlocks.addAll(exitBlocksSet);
LoopRegion loopRegion = makeLoopRegion(curRegion, loop, exitBlocks);
if (loopRegion == null) {
BlockNode exit = makeEndlessLoop(curRegion, stack, loop, loopStart);
insertContinue(loop);
return exit;
}
curRegion.getSubBlocks().add(loopRegion);
IRegion outerRegion = stack.peekRegion();
stack.push(loopRegion);
IfInfo condInfo = ifMaker.buildIfInfo(loopRegion);
if (!loop.getLoopBlocks().contains(condInfo.getThenBlock())) {
// invert loop condition if 'then' points to exit
condInfo = IfInfo.invert(condInfo);
}
loopRegion.updateCondition(condInfo);
// prevent if's merge with loop condition
condInfo.getMergedBlocks().forEach(b -> b.add(AFlag.ADDED_TO_REGION));
exitBlocks.removeAll(condInfo.getMergedBlocks());
if (!exitBlocks.isEmpty()) {
BlockNode loopExit = condInfo.getElseBlock();
if (loopExit != null) {
// add 'break' instruction before path cross between main loop exit and sub-exit
for (Edge exitEdge : loop.getExitEdges()) {
if (exitBlocks.contains(exitEdge.getSource())) {
insertLoopBreak(stack, loop, loopExit, exitEdge);
}
}
}
}
BlockNode out;
if (loopRegion.isConditionAtEnd()) {
BlockNode thenBlock = condInfo.getThenBlock();
out = thenBlock == loop.getEnd() || thenBlock == loopStart ? condInfo.getElseBlock() : thenBlock;
out = BlockUtils.followEmptyPath(out);
loopStart.remove(AType.LOOP);
loop.getEnd().add(AFlag.ADDED_TO_REGION);
stack.addExit(loop.getEnd());
regionMaker.clearBlockProcessedState(loopStart);
Region body = regionMaker.makeRegion(loopStart);
loopRegion.setBody(body);
loopStart.addAttr(AType.LOOP, loop);
loop.getEnd().remove(AFlag.ADDED_TO_REGION);
} else {
out = condInfo.getElseBlock();
if (outerRegion != null
&& out != null
&& out.contains(AFlag.LOOP_START)
&& !out.getAll(AType.LOOP).contains(loop)
&& RegionUtils.isRegionContainsBlock(outerRegion, out)) {
// exit to already processed outer loop
out = null;
}
stack.addExit(out);
BlockNode loopBody = condInfo.getThenBlock();
Region body;
if (Objects.equals(loopBody, loopStart)) {
// empty loop body
body = new Region(loopRegion);
} else {
body = regionMaker.makeRegion(loopBody);
}
// add blocks from loop start to first condition block
BlockNode conditionBlock = condInfo.getFirstIfBlock();
if (loopStart != conditionBlock) {
Set<BlockNode> blocks = BlockUtils.getAllPathsBlocks(loopStart, conditionBlock);
blocks.remove(conditionBlock);
for (BlockNode block : blocks) {
if (block.getInstructions().isEmpty()
&& !block.contains(AFlag.ADDED_TO_REGION)
&& !RegionUtils.isRegionContainsBlock(body, block)) {
body.add(block);
}
}
}
loopRegion.setBody(body);
}
stack.pop();
insertContinue(loop);
return out;
}
/**
* Select loop exit and construct LoopRegion
*/
private LoopRegion makeLoopRegion(IRegion curRegion, LoopInfo loop, List<BlockNode> exitBlocks) {
for (BlockNode block : exitBlocks) {
if (block.contains(AType.EXC_HANDLER)) {
continue;
}
InsnNode lastInsn = BlockUtils.getLastInsn(block);
if (lastInsn == null || lastInsn.getType() != InsnType.IF) {
continue;
}
List<LoopInfo> loops = block.getAll(AType.LOOP);
if (!loops.isEmpty() && loops.get(0) != loop) {
// skip nested loop condition
continue;
}
boolean exitAtLoopEnd = isExitAtLoopEnd(block, loop);
LoopRegion loopRegion = new LoopRegion(curRegion, loop, block, exitAtLoopEnd);
boolean found;
if (block == loop.getStart() || exitAtLoopEnd
|| BlockUtils.isEmptySimplePath(loop.getStart(), block)) {
found = true;
} else if (block.getPredecessors().contains(loop.getStart())) {
loopRegion.setPreCondition(loop.getStart());
// if we can't merge pre-condition this is not correct header
found = loopRegion.checkPreCondition();
} else {
found = false;
}
if (found) {
List<LoopInfo> list = mth.getAllLoopsForBlock(block);
if (list.size() >= 2) {
// bad condition if successors going out of all loops
boolean allOuter = true;
for (BlockNode outerBlock : block.getCleanSuccessors()) {
List<LoopInfo> outLoopList = mth.getAllLoopsForBlock(outerBlock);
outLoopList.remove(loop);
if (!outLoopList.isEmpty()) {
// goes to outer loop
allOuter = false;
break;
}
}
if (allOuter) {
found = false;
}
}
}
if (found && !checkLoopExits(loop, block)) {
found = false;
}
if (found) {
return loopRegion;
}
}
// no exit found => endless loop
return null;
}
private static boolean isExitAtLoopEnd(BlockNode exit, LoopInfo loop) {
BlockNode loopEnd = loop.getEnd();
if (exit == loopEnd) {
return true;
}
BlockNode loopStart = loop.getStart();
if (loopStart.getInstructions().isEmpty() && ListUtils.isSingleElement(loopStart.getSuccessors(), exit)) {
return false;
}
return loopEnd.getInstructions().isEmpty() && ListUtils.isSingleElement(loopEnd.getPredecessors(), exit);
}
private boolean checkLoopExits(LoopInfo loop, BlockNode mainExitBlock) {
List<Edge> exitEdges = loop.getExitEdges();
if (exitEdges.size() < 2) {
return true;
}
Optional<Edge> mainEdgeOpt = exitEdges.stream().filter(edge -> edge.getSource() == mainExitBlock).findFirst();
if (mainEdgeOpt.isEmpty()) {
throw new JadxRuntimeException("Not found exit edge by exit block: " + mainExitBlock);
}
Edge mainExitEdge = mainEdgeOpt.get();
BlockNode mainOutBlock = mainExitEdge.getTarget();
for (Edge exitEdge : exitEdges) {
if (exitEdge != mainExitEdge) {
// all exit paths must be same or don't cross (will be inside loop)
BlockNode exitBlock = exitEdge.getTarget();
if (!BlockUtils.isEqualPaths(mainOutBlock, exitBlock)) {
BlockNode crossBlock = BlockUtils.getPathCross(mth, mainOutBlock, exitBlock);
if (crossBlock != null) {
return false;
}
}
}
}
return true;
}
private BlockNode makeEndlessLoop(IRegion curRegion, RegionStack stack, LoopInfo loop, BlockNode loopStart) {
LoopRegion loopRegion = new LoopRegion(curRegion, loop, null, false);
curRegion.getSubBlocks().add(loopRegion);
loopStart.remove(AType.LOOP);
regionMaker.clearBlockProcessedState(loopStart);
stack.push(loopRegion);
BlockNode out = null;
// insert 'break' for exits
List<Edge> exitEdges = loop.getExitEdges();
if (exitEdges.size() == 1) {
Edge exitEdge = exitEdges.get(0);
BlockNode exit = exitEdge.getTarget();
if (insertLoopBreak(stack, loop, exit, exitEdge)) {
BlockNode nextBlock = getNextBlock(exit);
if (nextBlock != null) {
stack.addExit(nextBlock);
out = nextBlock;
}
}
} else {
for (Edge exitEdge : exitEdges) {
BlockNode exit = exitEdge.getTarget();
List<BlockNode> blocks = BlockUtils.bitSetToBlocks(mth, exit.getDomFrontier());
for (BlockNode block : blocks) {
if (BlockUtils.isPathExists(exit, block)) {
stack.addExit(block);
insertLoopBreak(stack, loop, block, exitEdge);
out = block;
} else {
insertLoopBreak(stack, loop, exit, exitEdge);
}
}
}
}
Region body = regionMaker.makeRegion(loopStart);
BlockNode loopEnd = loop.getEnd();
if (!RegionUtils.isRegionContainsBlock(body, loopEnd)
&& !loopEnd.contains(AType.EXC_HANDLER)
&& !inExceptionHandlerBlocks(loopEnd)) {
body.getSubBlocks().add(loopEnd);
}
loopRegion.setBody(body);
if (out == null) {
BlockNode next = getNextBlock(loopEnd);
out = RegionUtils.isRegionContainsBlock(body, next) ? null : next;
}
stack.pop();
loopStart.addAttr(AType.LOOP, loop);
return out;
}
private boolean inExceptionHandlerBlocks(BlockNode loopEnd) {
if (mth.getExceptionHandlersCount() == 0) {
return false;
}
for (ExceptionHandler eh : mth.getExceptionHandlers()) {
if (eh.getBlocks().contains(loopEnd)) {
return true;
}
}
return false;
}
private boolean canInsertBreak(BlockNode exit) {
if (BlockUtils.containsExitInsn(exit)) {
return false;
}
List<BlockNode> simplePath = BlockUtils.buildSimplePath(exit);
if (!simplePath.isEmpty()) {
BlockNode lastBlock = simplePath.get(simplePath.size() - 1);
if (lastBlock.isMthExitBlock()
|| lastBlock.isReturnBlock()
|| mth.isPreExitBlock(lastBlock)) {
return false;
}
}
// check if there no outer switch (TODO: very expensive check)
Set<BlockNode> paths = BlockUtils.getAllPathsBlocks(mth.getEnterBlock(), exit);
for (BlockNode block : paths) {
if (BlockUtils.checkLastInsnType(block, InsnType.SWITCH)) {
return false;
}
}
return true;
}
private boolean insertLoopBreak(RegionStack stack, LoopInfo loop, BlockNode loopExit, Edge exitEdge) {
BlockNode exit = exitEdge.getTarget();
Edge insertEdge = null;
boolean confirm = false;
// process special cases:
// 1. jump to outer loop
BlockNode exitEnd = BlockUtils.followEmptyPath(exit);
List<LoopInfo> loops = exitEnd.getAll(AType.LOOP);
for (LoopInfo loopAtEnd : loops) {
if (loopAtEnd != loop && loop.hasParent(loopAtEnd)) {
insertEdge = exitEdge;
confirm = true;
break;
}
}
if (!confirm) {
BlockNode insertBlock = null;
while (exit != null) {
if (insertBlock != null && isPathExists(loopExit, exit)) {
// found cross
if (canInsertBreak(insertBlock)) {
insertEdge = new Edge(insertBlock, insertBlock.getSuccessors().get(0));
confirm = true;
break;
}
return false;
}
insertBlock = exit;
List<BlockNode> cs = exit.getCleanSuccessors();
exit = cs.size() == 1 ? cs.get(0) : null;
}
}
if (!confirm) {
return false;
}
InsnNode breakInsn = new InsnNode(InsnType.BREAK, 0);
breakInsn.addAttr(AType.LOOP, loop);
EdgeInsnAttr.addEdgeInsn(insertEdge, breakInsn);
stack.addExit(exit);
// add label to 'break' if needed
addBreakLabel(exitEdge, exit, breakInsn);
return true;
}
private void addBreakLabel(Edge exitEdge, BlockNode exit, InsnNode breakInsn) {
BlockNode outBlock = BlockUtils.getNextBlock(exitEdge.getTarget());
if (outBlock == null) {
return;
}
List<LoopInfo> exitLoop = mth.getAllLoopsForBlock(outBlock);
if (!exitLoop.isEmpty()) {
return;
}
List<LoopInfo> inLoops = mth.getAllLoopsForBlock(exitEdge.getSource());
if (inLoops.size() < 2) {
return;
}
// search for parent loop
LoopInfo parentLoop = null;
for (LoopInfo loop : inLoops) {
if (loop.getParentLoop() == null) {
parentLoop = loop;
break;
}
}
if (parentLoop == null) {
return;
}
if (parentLoop.getEnd() != exit && !parentLoop.getExitNodes().contains(exit)) {
LoopLabelAttr labelAttr = new LoopLabelAttr(parentLoop);
breakInsn.addAttr(labelAttr);
parentLoop.getStart().addAttr(labelAttr);
}
}
private static void insertContinue(LoopInfo loop) {
BlockNode loopEnd = loop.getEnd();
List<BlockNode> predecessors = loopEnd.getPredecessors();
if (predecessors.size() <= 1) {
return;
}
Set<BlockNode> loopExitNodes = loop.getExitNodes();
for (BlockNode pred : predecessors) {
if (canInsertContinue(pred, predecessors, loopEnd, loopExitNodes)) {
InsnNode cont = new InsnNode(InsnType.CONTINUE, 0);
pred.getInstructions().add(cont);
}
}
}
private static boolean canInsertContinue(BlockNode pred, List<BlockNode> predecessors, BlockNode loopEnd,
Set<BlockNode> loopExitNodes) {
if (!pred.contains(AFlag.SYNTHETIC)
|| BlockUtils.checkLastInsnType(pred, InsnType.CONTINUE)) {
return false;
}
List<BlockNode> preds = pred.getPredecessors();
if (preds.isEmpty()) {
return false;
}
BlockNode codePred = preds.get(0);
if (codePred.contains(AFlag.ADDED_TO_REGION)) {
return false;
}
if (loopEnd.isDominator(codePred)
|| loopExitNodes.contains(codePred)) {
return false;
}
if (isDominatedOnBlocks(codePred, predecessors)) {
return false;
}
boolean gotoExit = false;
for (BlockNode exit : loopExitNodes) {
if (BlockUtils.isPathExists(codePred, exit)) {
gotoExit = true;
break;
}
}
return gotoExit;
}
private static boolean isDominatedOnBlocks(BlockNode dom, List<BlockNode> blocks) {
for (BlockNode node : blocks) {
if (!node.isDominator(dom)) {
return false;
}
}
return true;
}
}
@@ -0,0 +1,168 @@
package jadx.core.dex.visitors.regions.maker;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.List;
import java.util.Objects;
import jadx.core.dex.attributes.AFlag;
import jadx.core.dex.attributes.AType;
import jadx.core.dex.attributes.nodes.EdgeInsnAttr;
import jadx.core.dex.attributes.nodes.LoopInfo;
import jadx.core.dex.instructions.IfNode;
import jadx.core.dex.instructions.InsnType;
import jadx.core.dex.instructions.SwitchInsn;
import jadx.core.dex.nodes.BlockNode;
import jadx.core.dex.nodes.IRegion;
import jadx.core.dex.nodes.InsnContainer;
import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.regions.Region;
import jadx.core.utils.BlockUtils;
import jadx.core.utils.exceptions.JadxOverflowException;
import static jadx.core.utils.BlockUtils.getNextBlock;
public class RegionMaker {
private final MethodNode mth;
private final RegionStack stack;
private final IfRegionMaker ifMaker;
private final LoopRegionMaker loopMaker;
private final BitSet processedBlocks;
private final int regionsLimit;
private int regionsCount;
public RegionMaker(MethodNode mth) {
this.mth = mth;
this.stack = new RegionStack(mth);
this.ifMaker = new IfRegionMaker(mth, this);
this.loopMaker = new LoopRegionMaker(mth, this, ifMaker);
int blocksCount = mth.getBasicBlocks().size();
this.processedBlocks = new BitSet(blocksCount);
this.regionsLimit = blocksCount * 100;
}
public Region makeMthRegion() {
return makeRegion(mth.getEnterBlock());
}
Region makeRegion(BlockNode startBlock) {
Objects.requireNonNull(startBlock);
Region region = new Region(stack.peekRegion());
if (stack.containsExit(startBlock)) {
insertEdgeInsns(region, startBlock);
return region;
}
int startBlockId = startBlock.getId();
if (processedBlocks.get(startBlockId)) {
mth.addWarn("Removed duplicated region for block: " + startBlock + ' ' + startBlock.getAttributesString());
return region;
}
processedBlocks.set(startBlockId);
BlockNode next = startBlock;
while (next != null) {
next = traverse(region, next);
regionsCount++;
if (regionsCount > regionsLimit) {
throw new JadxOverflowException("Regions count limit reached");
}
}
return region;
}
/**
* Recursively traverse all blocks from 'block' until block from 'exits'
*/
private BlockNode traverse(IRegion r, BlockNode block) {
if (block.contains(AFlag.MTH_EXIT_BLOCK)) {
return null;
}
BlockNode next = null;
boolean processed = false;
List<LoopInfo> loops = block.getAll(AType.LOOP);
int loopCount = loops.size();
if (loopCount != 0 && block.contains(AFlag.LOOP_START)) {
if (loopCount == 1) {
next = loopMaker.process(r, loops.get(0), stack);
processed = true;
} else {
for (LoopInfo loop : loops) {
if (loop.getStart() == block) {
next = loopMaker.process(r, loop, stack);
processed = true;
break;
}
}
}
}
InsnNode insn = BlockUtils.getLastInsn(block);
if (!processed && insn != null) {
switch (insn.getType()) {
case IF:
next = ifMaker.process(r, block, (IfNode) insn, stack);
processed = true;
break;
case SWITCH:
SwitchRegionMaker switchMaker = new SwitchRegionMaker(mth, this);
next = switchMaker.process(r, block, (SwitchInsn) insn, stack);
processed = true;
break;
case MONITOR_ENTER:
SynchronizedRegionMaker syncMaker = new SynchronizedRegionMaker(mth, this);
next = syncMaker.process(r, block, insn, stack);
processed = true;
break;
}
}
if (!processed) {
r.getSubBlocks().add(block);
next = getNextBlock(block);
}
if (next != null && !stack.containsExit(block) && !stack.containsExit(next)) {
return next;
}
return null;
}
private void insertEdgeInsns(Region region, BlockNode exitBlock) {
List<EdgeInsnAttr> edgeInsns = exitBlock.getAll(AType.EDGE_INSN);
if (edgeInsns.isEmpty()) {
return;
}
List<InsnNode> insns = new ArrayList<>(edgeInsns.size());
addOneInsnOfType(insns, edgeInsns, InsnType.BREAK);
addOneInsnOfType(insns, edgeInsns, InsnType.CONTINUE);
region.add(new InsnContainer(insns));
}
private void addOneInsnOfType(List<InsnNode> insns, List<EdgeInsnAttr> edgeInsns, InsnType insnType) {
for (EdgeInsnAttr edgeInsn : edgeInsns) {
InsnNode insn = edgeInsn.getInsn();
if (insn.getType() == insnType) {
insns.add(insn);
return;
}
}
}
RegionStack getStack() {
return stack;
}
boolean isProcessed(BlockNode block) {
return processedBlocks.get(block.getId());
}
void clearBlockProcessedState(BlockNode block) {
processedBlocks.clear(block.getId());
}
}
@@ -1,4 +1,4 @@
package jadx.core.dex.visitors.regions;
package jadx.core.dex.visitors.regions.maker;
import java.util.ArrayDeque;
import java.util.Collection;
@@ -31,7 +31,7 @@ final class RegionStack {
IRegion region;
public State() {
exits = new HashSet<>(4);
exits = new HashSet<>();
}
private State(State c, IRegion region) {
@@ -113,6 +113,12 @@ final class RegionStack {
return stack.size();
}
public RegionStack clear() {
stack.clear();
curState = new State();
return this;
}
@Override
public String toString() {
return "Region stack size: " + size() + ", last: " + curState;
@@ -0,0 +1,288 @@
package jadx.core.dex.visitors.regions.maker;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.jetbrains.annotations.Nullable;
import jadx.core.dex.attributes.AFlag;
import jadx.core.dex.attributes.nodes.LoopInfo;
import jadx.core.dex.attributes.nodes.RegionRefAttr;
import jadx.core.dex.instructions.InsnType;
import jadx.core.dex.instructions.SwitchInsn;
import jadx.core.dex.nodes.BlockNode;
import jadx.core.dex.nodes.IRegion;
import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.regions.Region;
import jadx.core.dex.regions.SwitchRegion;
import jadx.core.utils.BlockUtils;
import jadx.core.utils.RegionUtils;
import jadx.core.utils.Utils;
import jadx.core.utils.exceptions.JadxRuntimeException;
final class SwitchRegionMaker {
private final MethodNode mth;
private final RegionMaker regionMaker;
SwitchRegionMaker(MethodNode mth, RegionMaker regionMaker) {
this.mth = mth;
this.regionMaker = regionMaker;
}
BlockNode process(IRegion currentRegion, BlockNode block, SwitchInsn insn, RegionStack stack) {
// map case blocks to keys
int len = insn.getTargets().length;
Map<BlockNode, List<Object>> blocksMap = new LinkedHashMap<>(len);
BlockNode[] targetBlocksArr = insn.getTargetBlocks();
for (int i = 0; i < len; i++) {
List<Object> keys = blocksMap.computeIfAbsent(targetBlocksArr[i], k -> new ArrayList<>(2));
keys.add(insn.getKey(i));
}
BlockNode defCase = insn.getDefTargetBlock();
if (defCase != null) {
List<Object> keys = blocksMap.computeIfAbsent(defCase, k -> new ArrayList<>(1));
keys.add(SwitchRegion.DEFAULT_CASE_KEY);
}
SwitchRegion sw = new SwitchRegion(currentRegion, block);
insn.addAttr(new RegionRefAttr(sw));
currentRegion.getSubBlocks().add(sw);
stack.push(sw);
BlockNode out = calcSwitchOut(block, stack);
stack.addExit(out);
processFallThroughCases(sw, out, stack, blocksMap);
removeEmptyCases(insn, sw, defCase);
stack.pop();
return out;
}
private void processFallThroughCases(SwitchRegion sw, @Nullable BlockNode out,
RegionStack stack, Map<BlockNode, List<Object>> blocksMap) {
Map<BlockNode, BlockNode> fallThroughCases = new LinkedHashMap<>();
if (out != null) {
// detect fallthrough cases
BitSet caseBlocks = BlockUtils.blocksToBitSet(mth, blocksMap.keySet());
caseBlocks.clear(out.getId());
for (BlockNode successor : sw.getHeader().getCleanSuccessors()) {
BitSet df = successor.getDomFrontier();
if (df.intersects(caseBlocks)) {
BlockNode fallThroughBlock = getOneIntersectionBlock(out, caseBlocks, df);
fallThroughCases.put(successor, fallThroughBlock);
}
}
// check fallthrough cases order
if (!fallThroughCases.isEmpty() && isBadCasesOrder(blocksMap, fallThroughCases)) {
Map<BlockNode, List<Object>> newBlocksMap = reOrderSwitchCases(blocksMap, fallThroughCases);
if (isBadCasesOrder(newBlocksMap, fallThroughCases)) {
mth.addWarnComment("Can't fix incorrect switch cases order, some code will duplicate");
fallThroughCases.clear();
} else {
blocksMap = newBlocksMap;
}
}
}
for (Map.Entry<BlockNode, List<Object>> entry : blocksMap.entrySet()) {
List<Object> keysList = entry.getValue();
BlockNode caseBlock = entry.getKey();
if (stack.containsExit(caseBlock)) {
sw.addCase(keysList, new Region(stack.peekRegion()));
} else {
BlockNode next = fallThroughCases.get(caseBlock);
stack.addExit(next);
Region caseRegion = regionMaker.makeRegion(caseBlock);
stack.removeExit(next);
if (next != null) {
next.add(AFlag.FALL_THROUGH);
caseRegion.add(AFlag.FALL_THROUGH);
}
sw.addCase(keysList, caseRegion);
// 'break' instruction will be inserted in RegionMakerVisitor.PostRegionVisitor
}
}
}
@Nullable
private BlockNode getOneIntersectionBlock(BlockNode out, BitSet caseBlocks, BitSet fallThroughSet) {
BitSet caseExits = BlockUtils.copyBlocksBitSet(mth, fallThroughSet);
caseExits.clear(out.getId());
caseExits.and(caseBlocks);
return BlockUtils.bitSetToOneBlock(mth, caseExits);
}
private @Nullable BlockNode calcSwitchOut(BlockNode block, RegionStack stack) {
// union of case blocks dominance frontier
// works if no fallthrough cases and no returns inside switch
BitSet outs = BlockUtils.newBlocksBitSet(mth);
for (BlockNode s : block.getCleanSuccessors()) {
if (s.contains(AFlag.LOOP_END)) {
// loop end dom frontier is loop start, ignore it
continue;
}
outs.or(s.getDomFrontier());
}
outs.clear(block.getId());
outs.clear(mth.getExitBlock().getId());
if (outs.isEmpty()) {
// switch already contains method exit
// add everything, out block not needed
return mth.getExitBlock();
}
BlockNode out = null;
if (outs.cardinality() == 1) {
// single exit
out = BlockUtils.bitSetToOneBlock(mth, outs);
} else {
// several switch exits
// possible 'return', 'continue' or fallthrough in one of the cases
LoopInfo loop = mth.getLoopForBlock(block);
if (loop != null) {
outs.andNot(loop.getStart().getPostDoms());
outs.andNot(loop.getEnd().getPostDoms());
BlockNode loopEnd = loop.getEnd();
if (outs.cardinality() == 2 && outs.get(loopEnd.getId())) {
// insert 'continue' for cases lead to loop end
// expect only 2 exits: loop end and switch out
List<BlockNode> outList = BlockUtils.bitSetToBlocks(mth, outs);
outList.remove(loopEnd);
BlockNode possibleOut = Utils.getOne(outList);
if (possibleOut != null && insertContinueInSwitch(block, possibleOut, loopEnd)) {
outs.clear(loopEnd.getId());
out = possibleOut;
}
}
}
if (out == null) {
BlockNode imPostDom = block.getIPostDom();
if (outs.get(imPostDom.getId())) {
out = imPostDom;
} else {
outs.andNot(block.getPostDoms());
out = BlockUtils.bitSetToOneBlock(mth, outs);
}
}
}
if (out != null && mth.isPreExitBlock(out)) {
// include 'return' or 'throw' in case blocks
out = mth.getExitBlock();
}
BlockNode imPostDom = block.getIPostDom();
if (out != imPostDom && !mth.isPreExitBlock(imPostDom)) {
// stop other paths at common exit
stack.addExit(imPostDom);
}
if (block.getCleanSuccessors().contains(imPostDom)) {
// add exit to stop on empty 'default' block
stack.addExit(imPostDom);
}
if (out == null) {
mth.addWarnComment("Failed to find 'out' block for switch in " + block + ". Please report as an issue.");
// fallback option; should work in most cases
out = block.getIPostDom();
}
if (out != null && regionMaker.isProcessed(out)) {
// 'out' block already processed, prevent endless loop
throw new JadxRuntimeException("Failed to find switch 'out' block (already processed)");
}
return out;
}
/**
* Remove empty case blocks:
* 1. single 'default' case
* 2. filler cases if switch is 'packed' and 'default' case is empty
*/
private void removeEmptyCases(SwitchInsn insn, SwitchRegion sw, BlockNode defCase) {
boolean defaultCaseIsEmpty;
if (defCase == null) {
defaultCaseIsEmpty = true;
} else {
defaultCaseIsEmpty = sw.getCases().stream()
.anyMatch(c -> c.getKeys().contains(SwitchRegion.DEFAULT_CASE_KEY)
&& RegionUtils.isEmpty(c.getContainer()));
}
if (defaultCaseIsEmpty) {
sw.getCases().removeIf(caseInfo -> {
if (RegionUtils.isEmpty(caseInfo.getContainer())) {
List<Object> keys = caseInfo.getKeys();
if (keys.contains(SwitchRegion.DEFAULT_CASE_KEY)) {
return true;
}
if (insn.isPacked()) {
return true;
}
}
return false;
});
}
}
private boolean isBadCasesOrder(Map<BlockNode, List<Object>> blocksMap, Map<BlockNode, BlockNode> fallThroughCases) {
BlockNode nextCaseBlock = null;
for (BlockNode caseBlock : blocksMap.keySet()) {
if (nextCaseBlock != null && !caseBlock.equals(nextCaseBlock)) {
return true;
}
nextCaseBlock = fallThroughCases.get(caseBlock);
}
return nextCaseBlock != null;
}
private Map<BlockNode, List<Object>> reOrderSwitchCases(Map<BlockNode, List<Object>> blocksMap,
Map<BlockNode, BlockNode> fallThroughCases) {
List<BlockNode> list = new ArrayList<>(blocksMap.size());
list.addAll(blocksMap.keySet());
list.sort((a, b) -> {
BlockNode nextA = fallThroughCases.get(a);
if (nextA != null) {
if (b.equals(nextA)) {
return -1;
}
} else if (a.equals(fallThroughCases.get(b))) {
return 1;
}
return 0;
});
Map<BlockNode, List<Object>> newBlocksMap = new LinkedHashMap<>(blocksMap.size());
for (BlockNode key : list) {
newBlocksMap.put(key, blocksMap.get(key));
}
return newBlocksMap;
}
private boolean insertContinueInSwitch(BlockNode switchBlock, BlockNode switchOut, BlockNode loopEnd) {
boolean inserted = false;
for (BlockNode caseBlock : switchBlock.getCleanSuccessors()) {
if (caseBlock.getDomFrontier().get(loopEnd.getId()) && caseBlock != switchOut) {
// search predecessor of loop end on path from this successor
Set<BlockNode> list = new HashSet<>(BlockUtils.collectBlocksDominatedBy(mth, caseBlock, caseBlock));
if (list.contains(switchOut) || switchOut.getPredecessors().stream().anyMatch(list::contains)) {
// 'continue' not needed
} else {
for (BlockNode p : loopEnd.getPredecessors()) {
if (list.contains(p)) {
if (p.isSynthetic()) {
p.getInstructions().add(new InsnNode(InsnType.CONTINUE, 0));
inserted = true;
}
break;
}
}
}
}
}
return inserted;
}
}
@@ -0,0 +1,162 @@
package jadx.core.dex.visitors.regions.maker;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import jadx.core.dex.attributes.AFlag;
import jadx.core.dex.instructions.InsnType;
import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.nodes.BlockNode;
import jadx.core.dex.nodes.IContainer;
import jadx.core.dex.nodes.IRegion;
import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.regions.Region;
import jadx.core.dex.regions.SynchronizedRegion;
import jadx.core.dex.visitors.regions.CleanRegions;
import jadx.core.utils.BlockUtils;
import jadx.core.utils.InsnRemover;
import jadx.core.utils.Utils;
import static jadx.core.utils.BlockUtils.getNextBlock;
import static jadx.core.utils.BlockUtils.isPathExists;
public class SynchronizedRegionMaker {
private static final Logger LOG = LoggerFactory.getLogger(SynchronizedRegionMaker.class);
private final MethodNode mth;
private final RegionMaker regionMaker;
SynchronizedRegionMaker(MethodNode mth, RegionMaker regionMaker) {
this.mth = mth;
this.regionMaker = regionMaker;
}
BlockNode process(IRegion curRegion, BlockNode block, InsnNode insn, RegionStack stack) {
SynchronizedRegion synchRegion = new SynchronizedRegion(curRegion, insn);
synchRegion.getSubBlocks().add(block);
curRegion.getSubBlocks().add(synchRegion);
Set<BlockNode> exits = new LinkedHashSet<>();
Set<BlockNode> cacheSet = new HashSet<>();
traverseMonitorExits(synchRegion, insn.getArg(0), block, exits, cacheSet);
for (InsnNode exitInsn : synchRegion.getExitInsns()) {
BlockNode insnBlock = BlockUtils.getBlockByInsn(mth, exitInsn);
if (insnBlock != null) {
insnBlock.add(AFlag.DONT_GENERATE);
}
// remove arg from MONITOR_EXIT to allow inline in MONITOR_ENTER
exitInsn.removeArg(0);
exitInsn.add(AFlag.DONT_GENERATE);
}
BlockNode body = getNextBlock(block);
if (body == null) {
mth.addWarn("Unexpected end of synchronized block");
return null;
}
BlockNode exit = null;
if (exits.size() == 1) {
exit = getNextBlock(exits.iterator().next());
} else if (exits.size() > 1) {
cacheSet.clear();
exit = traverseMonitorExitsCross(body, exits, cacheSet);
}
stack.push(synchRegion);
if (exit != null) {
stack.addExit(exit);
} else {
for (BlockNode exitBlock : exits) {
// don't add exit blocks which leads to method end blocks ('return', 'throw', etc)
List<BlockNode> list = BlockUtils.buildSimplePath(exitBlock);
if (list.isEmpty() || !BlockUtils.isExitBlock(mth, Utils.last(list))) {
stack.addExit(exitBlock);
// we can still try using this as an exit block to make sure it's visited.
exit = exitBlock;
}
}
}
synchRegion.getSubBlocks().add(regionMaker.makeRegion(body));
stack.pop();
return exit;
}
/**
* Traverse from monitor-enter thru successors and collect blocks contains monitor-exit
*/
private static void traverseMonitorExits(SynchronizedRegion region, InsnArg arg, BlockNode block, Set<BlockNode> exits,
Set<BlockNode> visited) {
visited.add(block);
for (InsnNode insn : block.getInstructions()) {
if (insn.getType() == InsnType.MONITOR_EXIT
&& insn.getArgsCount() > 0
&& insn.getArg(0).equals(arg)) {
exits.add(block);
region.getExitInsns().add(insn);
return;
}
}
for (BlockNode node : block.getSuccessors()) {
if (!visited.contains(node)) {
traverseMonitorExits(region, arg, node, exits, visited);
}
}
}
/**
* Traverse from monitor-enter thru successors and search for exit paths cross
*/
private static BlockNode traverseMonitorExitsCross(BlockNode block, Set<BlockNode> exits, Set<BlockNode> visited) {
visited.add(block);
for (BlockNode node : block.getCleanSuccessors()) {
boolean cross = true;
for (BlockNode exitBlock : exits) {
boolean p = isPathExists(exitBlock, node);
if (!p) {
cross = false;
break;
}
}
if (cross) {
return node;
}
if (!visited.contains(node)) {
BlockNode res = traverseMonitorExitsCross(node, exits, visited);
if (res != null) {
return res;
}
}
}
return null;
}
public static void removeSynchronized(MethodNode mth) {
Region startRegion = mth.getRegion();
List<IContainer> subBlocks = startRegion.getSubBlocks();
if (!subBlocks.isEmpty() && subBlocks.get(0) instanceof SynchronizedRegion) {
SynchronizedRegion synchRegion = (SynchronizedRegion) subBlocks.get(0);
InsnNode synchInsn = synchRegion.getEnterInsn();
if (!synchInsn.getArg(0).isThis()) {
LOG.warn("In synchronized method {}, top region not synchronized by 'this' {}", mth, synchInsn);
return;
}
// replace synchronized block with inner region
startRegion.getSubBlocks().set(0, synchRegion.getRegion());
// remove 'monitor-enter' instruction
InsnRemover.remove(mth, synchInsn);
// remove 'monitor-exit' instruction
for (InsnNode exit : synchRegion.getExitInsns()) {
InsnRemover.remove(mth, exit);
}
// run region cleaner again
CleanRegions.process(mth);
// assume that CodeShrinker will be run after this
}
}
}
@@ -1202,4 +1202,48 @@ public class BlockUtils {
}
return block.get(AType.EXC_CATCH);
}
public static boolean isEqualPaths(BlockNode b1, BlockNode b2) {
if (b1 == b2) {
return true;
}
if (b1 == null || b2 == null) {
return false;
}
return isEqualReturnBlocks(b1, b2) || isEmptySyntheticPath(b1, b2);
}
private static boolean isEmptySyntheticPath(BlockNode b1, BlockNode b2) {
BlockNode n1 = followEmptyPath(b1);
BlockNode n2 = followEmptyPath(b2);
return n1 == n2 || isEqualReturnBlocks(n1, n2);
}
public static boolean isEqualReturnBlocks(BlockNode b1, BlockNode b2) {
if (!b1.isReturnBlock() || !b2.isReturnBlock()) {
return false;
}
List<InsnNode> b1Insns = b1.getInstructions();
List<InsnNode> b2Insns = b2.getInstructions();
if (b1Insns.size() != 1 || b2Insns.size() != 1) {
return false;
}
InsnNode i1 = b1Insns.get(0);
InsnNode i2 = b2Insns.get(0);
if (i1.getArgsCount() != i2.getArgsCount()) {
return false;
}
if (i1.getArgsCount() == 0) {
return true;
}
InsnArg firstArg = i1.getArg(0);
InsnArg secondArg = i2.getArg(0);
if (firstArg.isSameConst(secondArg)) {
return true;
}
if (i1.getSourceLine() != i2.getSourceLine()) {
return false;
}
return firstArg.equals(secondArg);
}
}