refactor: replace recursion with loop for region traversal

This commit is contained in:
Skylot
2026-04-14 20:57:41 +01:00
parent ccc4164d54
commit 869422b424
7 changed files with 89 additions and 49 deletions
@@ -74,7 +74,9 @@ public final class SwitchRegion extends AbstractRegion implements IBranchRegion
public List<IContainer> getSubBlocks() {
List<IContainer> all = new ArrayList<>(cases.size() + 1);
all.add(header);
all.addAll(getCaseContainers());
for (CaseInfo caseInfo : cases) {
all.add(caseInfo.container);
}
return Collections.unmodifiableList(all);
}
@@ -16,7 +16,7 @@ import jadx.core.dex.info.ClassInfo;
import jadx.core.dex.instructions.args.ArgType;
import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.nodes.BlockNode;
import jadx.core.dex.nodes.IContainer;
import jadx.core.dex.nodes.IRegion;
import jadx.core.dex.nodes.MethodNode;
import jadx.core.utils.InsnUtils;
import jadx.core.utils.Utils;
@@ -30,7 +30,7 @@ public class ExceptionHandler {
private BlockNode handlerBlock;
private final List<BlockNode> blocks = new ArrayList<>();
private IContainer handlerRegion;
private IRegion handlerRegion;
private InsnArg arg;
private TryCatchBlockAttr tryBlock;
@@ -122,11 +122,11 @@ public class ExceptionHandler {
blocks.add(node);
}
public IContainer getHandlerRegion() {
public IRegion getHandlerRegion() {
return handlerRegion;
}
public void setHandlerRegion(IContainer handlerRegion) {
public void setHandlerRegion(IRegion handlerRegion) {
this.handlerRegion = handlerRegion;
}
@@ -1,11 +1,19 @@
package jadx.core.dex.visitors.regions;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import org.jetbrains.annotations.Nullable;
import jadx.core.dex.nodes.IBlock;
import jadx.core.dex.nodes.IContainer;
import jadx.core.dex.nodes.IRegion;
import jadx.core.dex.nodes.InsnContainer;
import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.trycatch.ExceptionHandler;
import jadx.core.utils.exceptions.JadxOverflowException;
import jadx.core.utils.ListUtils;
import jadx.core.utils.exceptions.JadxRuntimeException;
public class DepthRegionTraversal {
@@ -59,34 +67,61 @@ public class DepthRegionTraversal {
} while (repeat);
}
private static void traverseInternal(MethodNode mth, IRegionVisitor visitor, IContainer container) {
if (container instanceof IBlock) {
visitor.processBlock(mth, (IBlock) container);
} else if (container instanceof IRegion) {
IRegion region = (IRegion) container;
if (visitor.enterRegion(mth, region)) {
region.getSubBlocks().forEach(subCont -> traverseInternal(mth, visitor, subCont));
}
visitor.leaveRegion(mth, region);
}
}
private static final IContainer LEAVE_REGION_MARK = new InsnContainer(Collections.emptyList());
private static boolean traverseIterativeStepInternal(MethodNode mth, IRegionIterativeVisitor visitor, IContainer container) {
if (container instanceof IRegion) {
IRegion region = (IRegion) container;
if (visitor.visitRegion(mth, region)) {
return true;
private static void traverseInternal(MethodNode mth, IRegionVisitor visitor, IContainer startContainer) {
List<IContainer> stack = new ArrayList<>();
List<IRegion> regionLeaveStack = new ArrayList<>();
stack.add(startContainer);
while (true) {
IContainer current = ListUtils.removeLast(stack);
if (current == null) {
return;
}
for (IContainer subCont : region.getSubBlocks()) {
try {
if (traverseIterativeStepInternal(mth, visitor, subCont)) {
return true;
}
} catch (StackOverflowError overflow) {
throw new JadxOverflowException("Region traversal failed: Recursive call in traverseIterativeStepInternal method");
if (current == LEAVE_REGION_MARK) {
IRegion region = ListUtils.removeLast(regionLeaveStack);
visitor.leaveRegion(mth, Objects.requireNonNull(region));
} else if (current instanceof IBlock) {
visitor.processBlock(mth, (IBlock) current);
} else if (current instanceof IRegion) {
IRegion region = (IRegion) current;
boolean visitRegion = visitor.enterRegion(mth, region);
stack.add(LEAVE_REGION_MARK);
regionLeaveStack.add(region);
if (visitRegion) {
addSubBlocksToStack(stack, region);
}
}
}
}
private static void addSubBlocksToStack(List<IContainer> stack, IRegion region) {
List<IContainer> subBlocks = region.getSubBlocks();
// add in reverse order to keep original order during visit
for (int i = subBlocks.size() - 1; i >= 0; i--) {
stack.add(subBlocks.get(i));
}
}
private static boolean traverseIterativeStepInternal(MethodNode mth, IRegionIterativeVisitor visitor, IRegion startRegion) {
List<IRegion> stack = new ArrayList<>();
stack.add(startRegion);
while (true) {
IRegion region = ListUtils.removeLast(stack);
if (region == null) {
return false;
}
if (visitor.visitRegion(mth, region)) {
return true;
}
List<IContainer> subBlocks = region.getSubBlocks();
// add in reverse order to keep original order during visit
for (int i = subBlocks.size() - 1; i >= 0; i--) {
IContainer subBlock = subBlocks.get(i);
if (subBlock instanceof IRegion) {
stack.add((IRegion) subBlock);
}
}
}
return false;
}
}
@@ -61,7 +61,7 @@ public class ReturnVisitor extends AbstractVisitor {
if (mth.getLoopForBlock(block) != null) {
return false;
}
for (IRegion region : regionStack) {
for (IRegion region : getRegionStack()) {
if (region.getClass() == LoopRegion.class) {
return false;
}
@@ -74,7 +74,7 @@ public class ReturnVisitor extends AbstractVisitor {
*/
private boolean noTrailInstructions(BlockNode block) {
IContainer curContainer = block;
for (IRegion region : regionStack) {
for (IRegion region : getRegionStack()) {
// ignore paths on other branches
if (region instanceof IBranchRegion) {
curContainer = region;
@@ -2,14 +2,14 @@ package jadx.core.dex.visitors.regions;
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.Objects;
import jadx.core.dex.nodes.IBlock;
import jadx.core.dex.nodes.IRegion;
import jadx.core.dex.nodes.MethodNode;
public abstract class TracedRegionVisitor implements IRegionVisitor {
protected final Deque<IRegion> regionStack = new ArrayDeque<>();
private final Deque<IRegion> regionStack = new ArrayDeque<>();
@Override
public boolean enterRegion(MethodNode mth, IRegion region) {
@@ -19,7 +19,7 @@ public abstract class TracedRegionVisitor implements IRegionVisitor {
@Override
public void processBlock(MethodNode mth, IBlock block) {
IRegion curRegion = regionStack.peek();
IRegion curRegion = Objects.requireNonNull(regionStack.peek());
processBlockTraced(mth, block, curRegion);
}
@@ -29,4 +29,8 @@ public abstract class TracedRegionVisitor implements IRegionVisitor {
public void leaveRegion(MethodNode mth, IRegion region) {
regionStack.pop();
}
public Deque<IRegion> getRegionStack() {
return regionStack;
}
}
@@ -43,11 +43,11 @@ public class TypeUpdateInfo {
}
public @Nullable TypeUpdateRequest pollNextRequest() {
return ListUtils.pollLast(queue);
return ListUtils.removeLast(queue);
}
public @Nullable TypeUpdateRequest pollNextCallback() {
return ListUtils.pollLast(callbackQueue);
return ListUtils.removeLast(callbackQueue);
}
public void requestUpdate(InsnArg arg, ArgType changeType) {
@@ -67,6 +67,13 @@ public class ListUtils {
return list.get(0);
}
public static <T> T firstOrNull(List<T> list) {
if (list == null || list.isEmpty()) {
return null;
}
return list.get(0);
}
public static <T> @Nullable T last(List<T> list) {
if (list == null || list.isEmpty()) {
return null;
@@ -74,7 +81,10 @@ public class ListUtils {
return list.get(list.size() - 1);
}
public static <T> @Nullable T removeLast(List<T> list) {
public static <T> @Nullable T removeLast(@Nullable List<T> list) {
if (list == null) {
return null;
}
int size = list.size();
if (size == 0) {
return null;
@@ -227,15 +237,4 @@ public class ListUtils {
}
return list;
}
public static <T> @Nullable T pollLast(List<T> list) {
if (list == null) {
return null;
}
int size = list.size();
if (size == 0) {
return null;
}
return list.remove(size - 1);
}
}