From 2b7d7ce2cfa5a68289b0d38feb33319fe4fdce3b Mon Sep 17 00:00:00 2001 From: Skylot Date: Sat, 31 Oct 2020 15:59:43 +0000 Subject: [PATCH] fix: additional casts at use place to help type inference (#1002) --- .../java/jadx/core/dex/attributes/AFlag.java | 2 +- .../core/dex/instructions/args/InsnArg.java | 4 + .../jadx/core/dex/visitors/EnumVisitor.java | 3 +- .../core/dex/visitors/SimplifyVisitor.java | 3 +- .../typeinference/TypeInferenceVisitor.java | 59 ++++++++-- .../main/java/jadx/core/utils/BlockUtils.java | 26 +++- .../main/java/jadx/core/utils/DebugUtils.java | 4 + .../integration/types/TestTypeResolver16.java | 35 ++++++ .../test/smali/types/TestTypeResolver16.smali | 111 ++++++++++++++++++ 9 files changed, 227 insertions(+), 20 deletions(-) create mode 100644 jadx-core/src/test/java/jadx/tests/integration/types/TestTypeResolver16.java create mode 100644 jadx-core/src/test/smali/types/TestTypeResolver16.smali diff --git a/jadx-core/src/main/java/jadx/core/dex/attributes/AFlag.java b/jadx-core/src/main/java/jadx/core/dex/attributes/AFlag.java index 5f4bcf063..183255afa 100644 --- a/jadx-core/src/main/java/jadx/core/dex/attributes/AFlag.java +++ b/jadx-core/src/main/java/jadx/core/dex/attributes/AFlag.java @@ -67,7 +67,7 @@ public enum AFlag { */ EXPLICIT_PRIMITIVE_TYPE, EXPLICIT_CAST, - SOFT_CAST, // synthetic cast to help type inference + SOFT_CAST, // synthetic cast to help type inference (allow unchecked casts for generics) INCONSISTENT_CODE, // warning about incorrect decompilation diff --git a/jadx-core/src/main/java/jadx/core/dex/instructions/args/InsnArg.java b/jadx-core/src/main/java/jadx/core/dex/instructions/args/InsnArg.java index 2c563698c..8794f2f0e 100644 --- a/jadx-core/src/main/java/jadx/core/dex/instructions/args/InsnArg.java +++ b/jadx-core/src/main/java/jadx/core/dex/instructions/args/InsnArg.java @@ -206,6 +206,10 @@ public abstract class InsnArg extends Typed { return arg; } + public boolean isZeroLiteral() { + return isLiteral() && (((LiteralArg) this)).getLiteral() == 0; + } + public boolean isThis() { return contains(AFlag.THIS); } diff --git a/jadx-core/src/main/java/jadx/core/dex/visitors/EnumVisitor.java b/jadx-core/src/main/java/jadx/core/dex/visitors/EnumVisitor.java index 6a5873f55..5bb37a77a 100644 --- a/jadx-core/src/main/java/jadx/core/dex/visitors/EnumVisitor.java +++ b/jadx-core/src/main/java/jadx/core/dex/visitors/EnumVisitor.java @@ -29,7 +29,6 @@ import jadx.core.dex.instructions.InvokeNode; import jadx.core.dex.instructions.args.ArgType; import jadx.core.dex.instructions.args.InsnArg; import jadx.core.dex.instructions.args.InsnWrapArg; -import jadx.core.dex.instructions.args.LiteralArg; import jadx.core.dex.instructions.args.RegisterArg; import jadx.core.dex.instructions.args.SSAVar; import jadx.core.dex.instructions.mods.ConstructorInsn; @@ -209,7 +208,7 @@ public class EnumVisitor extends AbstractVisitor { case NEW_ARRAY: InsnArg arg = wrappedInsn.getArg(0); - if (arg.isLiteral() && ((LiteralArg) arg).getLiteral() == 0) { + if (arg.isZeroLiteral()) { // empty enum return Collections.emptyList(); } diff --git a/jadx-core/src/main/java/jadx/core/dex/visitors/SimplifyVisitor.java b/jadx-core/src/main/java/jadx/core/dex/visitors/SimplifyVisitor.java index 8b459f587..8a1b18373 100644 --- a/jadx-core/src/main/java/jadx/core/dex/visitors/SimplifyVisitor.java +++ b/jadx-core/src/main/java/jadx/core/dex/visitors/SimplifyVisitor.java @@ -260,8 +260,7 @@ public class SimplifyVisitor extends AbstractVisitor { if (f.isInsnWrap()) { InsnNode wi = ((InsnWrapArg) f).getWrapInsn(); if (wi.getType() == InsnType.CMP_L || wi.getType() == InsnType.CMP_G) { - if (insn.getArg(1).isLiteral() - && ((LiteralArg) insn.getArg(1)).getLiteral() == 0) { + if (insn.getArg(1).isZeroLiteral()) { insn.changeCondition(insn.getOp(), wi.getArg(0), wi.getArg(1)); } else { LOG.warn("TODO: cmp {}", insn); diff --git a/jadx-core/src/main/java/jadx/core/dex/visitors/typeinference/TypeInferenceVisitor.java b/jadx-core/src/main/java/jadx/core/dex/visitors/typeinference/TypeInferenceVisitor.java index dfa53470b..214a5c6fe 100644 --- a/jadx-core/src/main/java/jadx/core/dex/visitors/typeinference/TypeInferenceVisitor.java +++ b/jadx-core/src/main/java/jadx/core/dex/visitors/typeinference/TypeInferenceVisitor.java @@ -10,6 +10,7 @@ import java.util.Optional; import java.util.Set; import java.util.function.Function; +import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.Nullable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -332,6 +333,10 @@ public final class TypeInferenceVisitor extends AbstractVisitor { return invokeUseBound; } } + if (insn.getType() == InsnType.CHECK_CAST && insn.contains(AFlag.SOFT_CAST)) { + // ignore + return null; + } return new TypeBoundConst(BoundEnum.USE, regArg.getInitType(), regArg); } @@ -499,20 +504,30 @@ public final class TypeInferenceVisitor extends AbstractVisitor { if (insertAssignCast(mth, var, boundType)) { return 1; } - // TODO: check if use casts are needed - return 0; + return insertUseCasts(mth, var); } } return 0; } + private int insertUseCasts(MethodNode mth, SSAVar var) { + List useList = var.getUseList(); + if (useList.isEmpty()) { + return 0; + } + int useCasts = 0; + for (RegisterArg useReg : new ArrayList<>(useList)) { + if (insertSoftUseCast(mth, useReg)) { + useCasts++; + } + } + return useCasts; + } + private boolean insertAssignCast(MethodNode mth, SSAVar var, ArgType castType) { RegisterArg assignArg = var.getAssign(); InsnNode assignInsn = assignArg.getParentInsn(); - if (assignInsn == null) { - return false; - } - if (assignInsn.getType() == InsnType.PHI) { + if (assignInsn == null || assignInsn.getType() == InsnType.PHI) { return false; } BlockNode assignBlock = BlockUtils.getBlockByInsn(mth, assignInsn); @@ -521,14 +536,38 @@ public final class TypeInferenceVisitor extends AbstractVisitor { } RegisterArg newAssignArg = assignArg.duplicateWithNewSSAVar(mth); assignInsn.setResult(newAssignArg); + IndexInsnNode castInsn = makeSoftCastInsn(assignArg, newAssignArg, castType); + return BlockUtils.insertAfterInsn(assignBlock, assignInsn, castInsn); + } + private boolean insertSoftUseCast(MethodNode mth, RegisterArg useArg) { + InsnNode useInsn = useArg.getParentInsn(); + if (useInsn == null || useInsn.getType() == InsnType.PHI) { + return false; + } + if (useInsn.getType() == InsnType.IF && useInsn.getArg(1).isZeroLiteral()) { + // cast not needed if compare with null + return false; + } + BlockNode useBlock = BlockUtils.getBlockByInsn(mth, useInsn); + if (useBlock == null) { + return false; + } + RegisterArg newUseArg = useArg.duplicateWithNewSSAVar(mth); + useInsn.replaceArg(useArg, newUseArg); + + IndexInsnNode castInsn = makeSoftCastInsn(newUseArg, useArg, useArg.getInitType()); + return BlockUtils.insertBeforeInsn(useBlock, useInsn, castInsn); + } + + @NotNull + private IndexInsnNode makeSoftCastInsn(RegisterArg result, RegisterArg arg, ArgType castType) { IndexInsnNode castInsn = new IndexInsnNode(InsnType.CHECK_CAST, castType, 1); - castInsn.setResult(assignArg.duplicate()); - castInsn.addArg(newAssignArg.duplicate()); + castInsn.setResult(result.duplicate()); + castInsn.addArg(arg.duplicate()); castInsn.add(AFlag.SOFT_CAST); castInsn.add(AFlag.SYNTHETIC); - - return BlockUtils.insertAfterInsn(assignBlock, assignInsn, castInsn); + return castInsn; } private boolean trySplitConstInsns(MethodNode mth) { diff --git a/jadx-core/src/main/java/jadx/core/utils/BlockUtils.java b/jadx-core/src/main/java/jadx/core/utils/BlockUtils.java index 154741c72..122587a92 100644 --- a/jadx-core/src/main/java/jadx/core/utils/BlockUtils.java +++ b/jadx-core/src/main/java/jadx/core/utils/BlockUtils.java @@ -670,17 +670,33 @@ public class BlockUtils { return false; } + public static boolean insertBeforeInsn(BlockNode block, InsnNode insn, InsnNode newInsn) { + int index = getInsnIndexInBlock(block, insn); + if (index == -1) { + return false; + } + block.getInstructions().add(index, newInsn); + return true; + } + public static boolean insertAfterInsn(BlockNode block, InsnNode insn, InsnNode newInsn) { + int index = getInsnIndexInBlock(block, insn); + if (index == -1) { + return false; + } + block.getInstructions().add(index + 1, newInsn); + return true; + } + + public static int getInsnIndexInBlock(BlockNode block, InsnNode insn) { List instructions = block.getInstructions(); int size = instructions.size(); for (int i = 0; i < size; i++) { - InsnNode instruction = instructions.get(i); - if (instruction == insn) { - instructions.add(i + 1, newInsn); - return true; + if (instructions.get(i) == insn) { + return i; } } - return false; + return -1; } public static boolean replaceInsn(MethodNode mth, InsnNode oldInsn, InsnNode newInsn) { diff --git a/jadx-core/src/main/java/jadx/core/utils/DebugUtils.java b/jadx-core/src/main/java/jadx/core/utils/DebugUtils.java index cb55c9f88..c83b83d97 100644 --- a/jadx-core/src/main/java/jadx/core/utils/DebugUtils.java +++ b/jadx-core/src/main/java/jadx/core/utils/DebugUtils.java @@ -169,4 +169,8 @@ public class DebugUtils { LOG.debug(" {}: {}", entry.getKey(), entry.getValue()); } } + + public static void printStackTrace(String label) { + LOG.debug("StackTrace: {}\n{}", label, Utils.getStackTrace(new Exception())); + } } diff --git a/jadx-core/src/test/java/jadx/tests/integration/types/TestTypeResolver16.java b/jadx-core/src/test/java/jadx/tests/integration/types/TestTypeResolver16.java new file mode 100644 index 000000000..63500ab66 --- /dev/null +++ b/jadx-core/src/test/java/jadx/tests/integration/types/TestTypeResolver16.java @@ -0,0 +1,35 @@ +package jadx.tests.integration.types; + +import org.junit.jupiter.api.Test; + +import jadx.tests.api.SmaliTest; + +import static jadx.tests.api.utils.assertj.JadxAssertions.assertThat; + +/** + * Issue 1002 + * Insertion of additional cast (at use place) needed for successful type inference + */ +public class TestTypeResolver16 extends SmaliTest { + // @formatter:off + /* + public final List test(List list, Set set, Function function) { + checkParameterIsNotNull(function, "distinctBy"); + if (set != null) { + List union = list != null ? union(list, set, function) : null; + if (union != null) { + list = union; + } + } + return list != null ? (List) list : emptyList(); + } + */ + // @formatter:on + + @Test + public void test() { + assertThat(getClassNodeFromSmali()) + .code() + .containsOne("(List) list"); + } +} diff --git a/jadx-core/src/test/smali/types/TestTypeResolver16.smali b/jadx-core/src/test/smali/types/TestTypeResolver16.smali new file mode 100644 index 000000000..00e34e5d2 --- /dev/null +++ b/jadx-core/src/test/smali/types/TestTypeResolver16.smali @@ -0,0 +1,111 @@ +.class public Ltypes/TestTypeResolver16; +.super Ljava/lang/Object; + +.method public final test(Ljava/util/List;Ljava/util/Set;Ljava/util/function/Function;)Ljava/util/List; + .locals 1 + .annotation system Ldalvik/annotation/Signature; + value = { + "(", + "Ljava/util/List<", + "+TT;>;", + "Ljava/util/Set<", + "+TT;>;", + "Ljava/util/function/Function<", + "-TT;+TK;>;)", + "Ljava/util/List<", + "TT;>;" + } + .end annotation + + const-string v0, "distinctBy" + + invoke-static {p3, v0}, Ltypes/TestTypeResolver16;->checkParameterIsNotNull(Ljava/lang/Object;Ljava/lang/String;)V + + if-eqz p2, :cond_1 + + if-eqz p1, :cond_0 + + .line 85 + move-object v0, p1 + + check-cast v0, Ljava/util/Collection; + + check-cast p2, Ljava/lang/Iterable; + + invoke-static {v0, p2, p3}, Ltypes/TestTypeResolver16;->union(Ljava/util/Collection;Ljava/lang/Iterable;Ljava/util/function/Function;)Ljava/util/List; + + move-result-object p2 + + goto :goto_0 + + :cond_0 + const/4 p2, 0x0 + + :goto_0 + if-eqz p2, :cond_1 + + move-object p1, p2 + + :cond_1 + if-eqz p1, :cond_2 + + goto :goto_1 + + :cond_2 + invoke-static {}, Ltypes/TestTypeResolver16;->emptyList()Ljava/util/List; + + move-result-object p1 + + :goto_1 + return-object p1 +.end method + + +.method public static final union(Ljava/util/Collection;Ljava/lang/Iterable;Ljava/util/function/Function;)Ljava/util/List; + .locals 4 + .annotation system Ldalvik/annotation/Signature; + value = { + "(", + "Ljava/util/Collection<", + "+TT;>;", + "Ljava/lang/Iterable<", + "+TT;>;", + "Ljava/util/function/Function<", + "-TT;+TK;>;)", + "Ljava/util/List<", + "TT;>;" + } + .end annotation + + const/4 v0, 0x0 + return-object v0 +.end method + +.method public static checkParameterIsNotNull(Ljava/lang/Object;Ljava/lang/String;)V + .locals 0 + return-void +.end method + +.method public static final emptyList()Ljava/util/List; + .locals 1 + .annotation system Ldalvik/annotation/Signature; + value = { + "()", + "Ljava/util/List<", + "TT;>;" + } + .end annotation + + const/4 v0, 0x0 + return-object v0 +.end method