fix: improve try/catch temp edges injection (#2247)

This commit is contained in:
Skylot
2024-08-17 20:45:40 +01:00
parent eee354a3ab
commit 847225a6a9
4 changed files with 104 additions and 16 deletions
@@ -65,7 +65,6 @@ public class BlockSplitter extends AbstractVisitor {
Map<Integer, BlockNode> blocksMap = splitBasicBlocks(mth);
setupConnectionsFromJumps(mth, blocksMap);
initBlocksInTargetNodes(mth);
addTempConnectionsForExcHandlers(mth, blocksMap);
expandMoveMulti(mth);
if (mth.contains(AFlag.RESOLVE_JAVA_JSR)) {
@@ -76,6 +75,8 @@ public class BlockSplitter extends AbstractVisitor {
removeInsns(mth);
removeEmptyDetachedBlocks(mth);
mth.getBasicBlocks().removeIf(BlockSplitter::removeEmptyBlock);
addTempConnectionsForExcHandlers(mth, blocksMap);
setupExitConnections(mth);
mth.unloadInsnArr();
@@ -257,10 +258,13 @@ public class BlockSplitter extends AbstractVisitor {
/**
* Connect exception handlers to the throw block.
* This temporary connection needed to build close to final dominators tree.
* This temporary connection is necessary to build close to a final dominator tree.
* Will be used and removed in {@code jadx.core.dex.visitors.blocks.BlockExceptionHandler}
*/
private static void addTempConnectionsForExcHandlers(MethodNode mth, Map<Integer, BlockNode> blocksMap) {
if (mth.isNoExceptionHandlers()) {
return;
}
for (BlockNode block : mth.getBasicBlocks()) {
for (InsnNode insn : block.getInstructions()) {
CatchAttr catchAttr = insn.get(AType.EXC_CATCH);
@@ -4,10 +4,12 @@ import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.function.Supplier;
import jadx.core.dex.attributes.AFlag;
import jadx.core.dex.attributes.AType;
import jadx.core.dex.attributes.nodes.PhiListAttr;
import jadx.core.dex.attributes.nodes.TmpEdgeAttr;
import jadx.core.dex.instructions.IfNode;
import jadx.core.dex.instructions.InsnType;
import jadx.core.dex.instructions.PhiInsn;
@@ -61,14 +63,14 @@ public class DebugChecks {
}
for (BlockNode block : basicBlocks) {
for (InsnNode insn : block.getInstructions()) {
checkInsn(mth, insn);
checkInsn(mth, block, insn);
}
}
checkSSAVars(mth);
// checkPHI(mth);
}
private static void checkInsn(MethodNode mth, InsnNode insn) {
private static void checkInsn(MethodNode mth, BlockNode block, InsnNode insn) {
if (insn.getResult() != null) {
checkVar(mth, insn, insn.getResult());
}
@@ -77,24 +79,45 @@ public class DebugChecks {
checkVar(mth, insn, (RegisterArg) arg);
} else if (arg.isInsnWrap()) {
InsnNode wrapInsn = ((InsnWrapArg) arg).getWrapInsn();
checkInsn(mth, wrapInsn);
checkInsn(mth, block, wrapInsn);
}
}
if (insn instanceof TernaryInsn) {
TernaryInsn ternaryInsn = (TernaryInsn) insn;
for (RegisterArg arg : ternaryInsn.getCondition().getRegisterArgs()) {
checkVar(mth, insn, arg);
}
} else if (insn instanceof IfNode) {
IfNode ifNode = (IfNode) insn;
checkBlock(mth, ifNode.getThenBlock());
checkBlock(mth, ifNode.getElseBlock());
switch (insn.getType()) {
case TERNARY:
TernaryInsn ternaryInsn = (TernaryInsn) insn;
for (RegisterArg arg : ternaryInsn.getCondition().getRegisterArgs()) {
checkVar(mth, insn, arg);
}
break;
case IF:
IfNode ifNode = (IfNode) insn;
if (!ifNode.getThenBlock().equals(ifNode.getElseBlock())) {
// exclude temp edges
int branches = (int) block.getSuccessors().stream().filter(b -> !hasTmpEdge(block, b)).count();
if (branches != 2) {
DebugUtils.dumpRaw(mth, "error");
throw new JadxRuntimeException(
"Incorrect if block successors count: " + branches + " (expect 2), block: " + block);
}
}
checkBlock(mth, ifNode.getThenBlock(), () -> "then block in if insn: " + ifNode);
checkBlock(mth, ifNode.getElseBlock(), () -> "else block in if insn: " + ifNode);
break;
}
}
private static void checkBlock(MethodNode mth, BlockNode block) {
private static boolean hasTmpEdge(BlockNode start, BlockNode end) {
TmpEdgeAttr tmpEdgeAttr = end.get(AType.TMP_EDGE);
if (tmpEdgeAttr == null) {
return false;
}
return tmpEdgeAttr.getBlock().equals(start);
}
private static void checkBlock(MethodNode mth, BlockNode block, Supplier<String> source) {
if (!mth.getBasicBlocks().contains(block)) {
throw new JadxRuntimeException("Block not registered in method: " + block);
throw new JadxRuntimeException("Block not registered in method: " + block + " from " + source.get());
}
}
@@ -0,0 +1,17 @@
package jadx.tests.integration.trycatch;
import org.junit.jupiter.api.Test;
import jadx.tests.api.SmaliTest;
import static jadx.tests.api.utils.assertj.JadxAssertions.assertThat;
public class TestTryCatch10 extends SmaliTest {
@Test
public void test() {
assertThat(getClassNodeFromSmali())
.code()
.countString(3, "return false;");
}
}
@@ -0,0 +1,44 @@
.class public Ltrycatch/TestTryCatch10;
.super Ljava/lang/Object;
.field public static VERSION:I
.method public static test(I)Z
.registers 5
sget v0, Ltrycatch/TestTryCatch10;->VERSION:I
const/16 v1, 0x1d
const/4 v2, 0x0
if-lt v0, v1, :cond_1b
const-string v0, "custom"
invoke-static {v0}, Ltrycatch/TestTryCatch10;->check(Ljava/lang/String;)Z
move-result v0
if-nez v0, :cond_10
goto :goto_1b
:cond_10
:try_start_10
invoke-static {p0}, Ltrycatch/TestTryCatch10;->getVar(I)I
move-result p0
:try_end_18
.catch Ljava/lang/Exception; {:try_start_10 .. :try_end_18} :catch_1b
if-eqz p0, :cond_1b
const/4 v2, 0x1
:catch_1b
:cond_1b
:goto_1b
return v2
.end method
.method public static getVar(I)I
.locals 0
return p0
.end method
.method public static check(Ljava/lang/String;)Z
.locals 1
const/4 v0, 0x0
return v0
.end method