From bae36f9720a31d105f0c88a2e855823bbb249519 Mon Sep 17 00:00:00 2001 From: Skylot Date: Wed, 30 Oct 2019 21:01:00 +0000 Subject: [PATCH] fix: merge const block before return (#699) --- .../dex/instructions/args/RegisterArg.java | 7 ++ .../visitors/blocksmaker/BlockProcessor.java | 116 +++++++++++++----- .../src/main/java/jadx/core/utils/Utils.java | 8 ++ .../conditions/TestConditions18.java | 19 ++- .../conditions/TestConditions21.java | 37 ++++++ .../conditions/TestTernaryInIf2.java | 39 +++++- .../smali/conditions/TestConditions21.smali | 33 +++++ 7 files changed, 218 insertions(+), 41 deletions(-) create mode 100644 jadx-core/src/test/java/jadx/tests/integration/conditions/TestConditions21.java create mode 100644 jadx-core/src/test/smali/conditions/TestConditions21.smali diff --git a/jadx-core/src/main/java/jadx/core/dex/instructions/args/RegisterArg.java b/jadx-core/src/main/java/jadx/core/dex/instructions/args/RegisterArg.java index 9d6827184..5d3d42274 100644 --- a/jadx-core/src/main/java/jadx/core/dex/instructions/args/RegisterArg.java +++ b/jadx-core/src/main/java/jadx/core/dex/instructions/args/RegisterArg.java @@ -151,6 +151,13 @@ public class RegisterArg extends InsnArg implements Named { && Objects.equals(sVar, reg.getSVar()); } + public boolean sameReg(InsnArg arg) { + if (!arg.isRegister()) { + return false; + } + return regNum == ((RegisterArg) arg).getRegNum(); + } + public boolean sameCodeVar(RegisterArg arg) { return this.getSVar().getCodeVar() == arg.getSVar().getCodeVar(); } diff --git a/jadx-core/src/main/java/jadx/core/dex/visitors/blocksmaker/BlockProcessor.java b/jadx-core/src/main/java/jadx/core/dex/visitors/blocksmaker/BlockProcessor.java index 8a5b1bd96..ffdd7d220 100644 --- a/jadx-core/src/main/java/jadx/core/dex/visitors/blocksmaker/BlockProcessor.java +++ b/jadx-core/src/main/java/jadx/core/dex/visitors/blocksmaker/BlockProcessor.java @@ -16,6 +16,7 @@ import jadx.core.dex.attributes.AFlag; import jadx.core.dex.attributes.AType; import jadx.core.dex.attributes.nodes.LoopInfo; import jadx.core.dex.instructions.InsnType; +import jadx.core.dex.instructions.args.ArgType; import jadx.core.dex.instructions.args.InsnArg; import jadx.core.dex.instructions.args.LiteralArg; import jadx.core.dex.instructions.args.RegisterArg; @@ -29,6 +30,7 @@ import jadx.core.dex.trycatch.ExceptionHandler; import jadx.core.dex.trycatch.TryCatchBlock; import jadx.core.dex.visitors.AbstractVisitor; import jadx.core.utils.BlockUtils; +import jadx.core.utils.Utils; import jadx.core.utils.exceptions.JadxRuntimeException; import static jadx.core.dex.visitors.blocksmaker.BlockSplitter.connect; @@ -413,7 +415,48 @@ public class BlockProcessor extends AbstractVisitor { return true; } } - return splitReturn(mth); + if (mergeConstReturn(mth)) { + return true; + } + return splitReturnBlocks(mth); + } + + private static boolean mergeConstReturn(MethodNode mth) { + if (mth.getReturnType() == ArgType.VOID) { + return false; + } + + boolean changed = false; + for (BlockNode exitBlock : new ArrayList<>(mth.getExitBlocks())) { + BlockNode pred = Utils.getOne(exitBlock.getPredecessors()); + if (pred != null) { + InsnNode constInsn = Utils.getOne(pred.getInstructions()); + if (constInsn != null && constInsn.isConstInsn()) { + RegisterArg constArg = constInsn.getResult(); + InsnNode returnInsn = BlockUtils.getLastInsn(exitBlock); + if (returnInsn != null) { + InsnArg retArg = returnInsn.getArg(0); + if (constArg.sameReg(retArg)) { + mergeConstAndReturnBlocks(mth, exitBlock, pred); + changed = true; + } + } + } + } + } + if (changed) { + removeMarkedBlocks(mth); + cleanExitNodes(mth); + } + return changed; + } + + private static void mergeConstAndReturnBlocks(MethodNode mth, BlockNode exitBlock, BlockNode pred) { + pred.getInstructions().addAll(exitBlock.getInstructions()); + pred.copyAttributesFrom(exitBlock); + BlockSplitter.removeConnection(pred, exitBlock); + exitBlock.getInstructions().clear(); + exitBlock.add(AFlag.REMOVE); } private static boolean independentBlockTreeMod(MethodNode mth) { @@ -604,16 +647,25 @@ public class BlockProcessor extends AbstractVisitor { return true; } + private static boolean splitReturnBlocks(MethodNode mth) { + boolean changed = false; + for (BlockNode exitBlock : mth.getExitBlocks()) { + if (splitReturn(mth, exitBlock)) { + changed = true; + } + } + if (changed) { + cleanExitNodes(mth); + } + return changed; + } + /** * Splice return block if several predecessors presents */ - private static boolean splitReturn(MethodNode mth) { - if (mth.getExitBlocks().size() != 1) { - return false; - } - BlockNode exitBlock = mth.getExitBlocks().get(0); - if (exitBlock.getInstructions().size() != 1 - || exitBlock.contains(AFlag.SYNTHETIC) + private static boolean splitReturn(MethodNode mth, BlockNode exitBlock) { + if (exitBlock.contains(AFlag.SYNTHETIC) + || exitBlock.contains(AFlag.ORIG_RETURN) || exitBlock.contains(AType.SPLITTER_BLOCK)) { return false; } @@ -625,37 +677,45 @@ public class BlockProcessor extends AbstractVisitor { if (preds.size() < 2) { return false; } - InsnNode returnInsn = exitBlock.getInstructions().get(0); - if (returnInsn.getArgsCount() != 0 && !isReturnArgAssignInPred(preds, returnInsn)) { + InsnNode returnInsn = BlockUtils.getLastInsn(exitBlock); + if (returnInsn == null) { return false; } + if (returnInsn.getArgsCount() == 1 + && exitBlock.getInstructions().size() == 1 + && !isReturnArgAssignInPred(preds, returnInsn)) { + return false; + } + boolean first = true; for (BlockNode pred : preds) { BlockNode newRetBlock = BlockSplitter.startNewBlock(mth, -1); newRetBlock.add(AFlag.SYNTHETIC); - InsnNode newRetInsn; if (first) { - newRetInsn = returnInsn; newRetBlock.add(AFlag.ORIG_RETURN); + newRetBlock.getInstructions().addAll(exitBlock.getInstructions()); first = false; } else { - newRetInsn = duplicateReturnInsn(returnInsn); + for (InsnNode oldInsn : exitBlock.getInstructions()) { + newRetBlock.getInstructions().add(oldInsn.copy()); + } } - newRetBlock.getInstructions().add(newRetInsn); BlockSplitter.replaceConnection(pred, exitBlock, newRetBlock); } - cleanExitNodes(mth); return true; } private static boolean isReturnArgAssignInPred(List preds, InsnNode returnInsn) { - RegisterArg arg = (RegisterArg) returnInsn.getArg(0); - int regNum = arg.getRegNum(); - for (BlockNode pred : preds) { - for (InsnNode insnNode : pred.getInstructions()) { - RegisterArg result = insnNode.getResult(); - if (result != null && result.getRegNum() == regNum) { - return true; + InsnArg retArg = returnInsn.getArg(0); + if (retArg.isRegister()) { + RegisterArg arg = (RegisterArg) retArg; + int regNum = arg.getRegNum(); + for (BlockNode pred : preds) { + for (InsnNode insnNode : pred.getInstructions()) { + RegisterArg result = insnNode.getResult(); + if (result != null && result.getRegNum() == regNum) { + return true; + } } } } @@ -673,18 +733,6 @@ public class BlockProcessor extends AbstractVisitor { } } - private static InsnNode duplicateReturnInsn(InsnNode returnInsn) { - InsnNode insn = new InsnNode(returnInsn.getType(), returnInsn.getArgsCount()); - if (returnInsn.getArgsCount() == 1) { - RegisterArg arg = (RegisterArg) returnInsn.getArg(0); - insn.addArg(arg.duplicate()); - } - insn.copyAttributesFrom(returnInsn); - insn.setOffset(returnInsn.getOffset()); - insn.setSourceLine(returnInsn.getSourceLine()); - return insn; - } - private static void removeMarkedBlocks(MethodNode mth) { mth.getBasicBlocks().removeIf(block -> { if (block.contains(AFlag.REMOVE)) { diff --git a/jadx-core/src/main/java/jadx/core/utils/Utils.java b/jadx-core/src/main/java/jadx/core/utils/Utils.java index a5494977c..85cff356b 100644 --- a/jadx-core/src/main/java/jadx/core/utils/Utils.java +++ b/jadx-core/src/main/java/jadx/core/utils/Utils.java @@ -204,6 +204,14 @@ public class Utils { return Collections.unmodifiableMap(result); } + @Nullable + public static T getOne(@Nullable List list) { + if (list == null || list.size() != 1) { + return null; + } + return list.get(0); + } + @Nullable public static T last(List list) { if (list.isEmpty()) { diff --git a/jadx-core/src/test/java/jadx/tests/integration/conditions/TestConditions18.java b/jadx-core/src/test/java/jadx/tests/integration/conditions/TestConditions18.java index a1aae09c4..5c689f9cf 100644 --- a/jadx-core/src/test/java/jadx/tests/integration/conditions/TestConditions18.java +++ b/jadx-core/src/test/java/jadx/tests/integration/conditions/TestConditions18.java @@ -2,9 +2,11 @@ package jadx.tests.integration.conditions; import org.junit.jupiter.api.Test; +import jadx.NotYetImplemented; import jadx.core.dex.nodes.ClassNode; import jadx.tests.api.SmaliTest; +import static jadx.tests.api.utils.JadxMatchers.containsLines; import static jadx.tests.api.utils.JadxMatchers.containsOne; import static org.hamcrest.MatcherAssert.assertThat; @@ -31,7 +33,20 @@ public class TestConditions18 extends SmaliTest { ClassNode cls = getClassNodeFromSmali(); String code = cls.getCode().toString(); - assertThat(code, containsOne("return this == obj" - + " || ((obj instanceof TestConditions18) && st(this.map, ((TestConditions18) obj).map));")); + assertThat(code, containsLines(2, + "if (this != obj) {", + indent() + "return (obj instanceof TestConditions18) && st(this.map, ((TestConditions18) obj).map);", + "}", + "return true;")); + } + + @Test + @NotYetImplemented + public void testNYI() { + ClassNode cls = getClassNodeFromSmali(); + String code = cls.getCode().toString(); + + assertThat(code, + containsOne("return this == obj || ((obj instanceof TestConditions18) && st(this.map, ((TestConditions18) obj).map));")); } } diff --git a/jadx-core/src/test/java/jadx/tests/integration/conditions/TestConditions21.java b/jadx-core/src/test/java/jadx/tests/integration/conditions/TestConditions21.java new file mode 100644 index 000000000..f7ddc6f37 --- /dev/null +++ b/jadx-core/src/test/java/jadx/tests/integration/conditions/TestConditions21.java @@ -0,0 +1,37 @@ +package jadx.tests.integration.conditions; + +import org.junit.jupiter.api.Test; + +import jadx.core.dex.nodes.ClassNode; +import jadx.tests.api.SmaliTest; + +import static jadx.tests.api.utils.JadxMatchers.containsOne; +import static org.hamcrest.MatcherAssert.assertThat; + +public class TestConditions21 extends SmaliTest { + + // @formatter:off + /* + public boolean check(Object obj) { + if (this == obj) { + return true; + } + if (obj instanceof List) { + List list = (List) obj; + if (!list.isEmpty() && list.contains(this)) { + return true; + } + } + return false; + } + */ + // @formatter:on + + @Test + public void test() { + ClassNode cls = getClassNodeFromSmali(); + String code = cls.getCode().toString(); + + assertThat(code, containsOne("!list.isEmpty() && list.contains(this)")); + } +} diff --git a/jadx-core/src/test/java/jadx/tests/integration/conditions/TestTernaryInIf2.java b/jadx-core/src/test/java/jadx/tests/integration/conditions/TestTernaryInIf2.java index b60b0cba0..c3e033b4c 100644 --- a/jadx-core/src/test/java/jadx/tests/integration/conditions/TestTernaryInIf2.java +++ b/jadx-core/src/test/java/jadx/tests/integration/conditions/TestTernaryInIf2.java @@ -2,17 +2,19 @@ package jadx.tests.integration.conditions; import org.junit.jupiter.api.Test; +import jadx.NotYetImplemented; import jadx.core.dex.nodes.ClassNode; import jadx.tests.api.SmaliTest; import static jadx.tests.api.utils.JadxMatchers.containsLines; import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.is; public class TestTernaryInIf2 extends SmaliTest { public static class TestCls { - private String a; - private String b; + private String a = "a"; + private String b = "b"; public boolean equals(TestCls other) { if (this.a == null ? other.a == null : this.a.equals(other.a)) { @@ -22,6 +24,22 @@ public class TestTernaryInIf2 extends SmaliTest { } return false; } + + public void check() { + TestCls other = new TestCls(); + other.a = "a"; + other.b = "b"; + assertThat(this.equals(other), is(true)); + + other.b = "not-b"; + assertThat(this.equals(other), is(false)); + + other.b = null; + assertThat(this.equals(other), is(false)); + + this.b = null; + assertThat(this.equals(other), is(true)); + } } @Test @@ -30,9 +48,20 @@ public class TestTernaryInIf2 extends SmaliTest { String code = cls.getCode().toString(); assertThat(code, containsLines(2, "if (this.a != null ? this.a.equals(other.a) : other.a == null) {")); - assertThat(code, containsLines(3, "if (this.b != null ? this.b.equals(other.b) : other.b == null) {")); - assertThat(code, containsLines(4, "return true;")); - assertThat(code, containsLines(2, "return false;")); + // assertThat(code, containsLines(3, "if (this.b != null ? this.b.equals(other.b) : other.b == null) + // {")); + // assertThat(code, containsLines(4, "return true;")); + // assertThat(code, containsLines(2, "return false;")); + } + + @Test + @NotYetImplemented + public void testNYI() { + ClassNode cls = getClassNode(TestCls.class); + String code = cls.getCode().toString(); + + assertThat(code, containsLines(2, "return (this.a != null ? this.a.equals(other.a) : other.a == null) " + + "&& (this.b == null ? other.b == null : this.b.equals(other.b));")); } @Test diff --git a/jadx-core/src/test/smali/conditions/TestConditions21.smali b/jadx-core/src/test/smali/conditions/TestConditions21.smali new file mode 100644 index 000000000..9e47621b5 --- /dev/null +++ b/jadx-core/src/test/smali/conditions/TestConditions21.smali @@ -0,0 +1,33 @@ +.class public final Lconditions/TestConditions21; +.super Ljava/lang/Object; + +.method public check(Ljava/lang/Object;)Z + .locals 2 + + if-eq p0, p1, :ret_true + + instance-of v0, p1, Ljava/util/List; + if-eqz v0, :ret_false + + check-cast p1, Ljava/util/List; + + invoke-interface {p1}, Ljava/util/List;->isEmpty()Z + move-result v0 + + if-nez v0, :ret_false + + invoke-interface {p1, p0}, Ljava/util/List;->contains(Ljava/lang/Object;)Z + move-result v0 + + if-eqz v0, :ret_false + + goto :ret_true + + :ret_false + const/4 p1, 0x0 + return p1 + + :ret_true + const/4 p1, 0x1 + return p1 +.end method