From 6192ced214dc05f6848ca786d2d38a900f717bbe Mon Sep 17 00:00:00 2001 From: Skylot Date: Mon, 1 Jun 2020 19:59:28 +0100 Subject: [PATCH] fix: improve type inference of type variables in method invoke (#913) --- jadx-core/src/main/java/jadx/core/Consts.java | 1 + .../java/jadx/core/dex/attributes/AType.java | 2 + .../attributes/nodes/MethodTypeVarsAttr.java | 33 +++++ .../core/dex/instructions/args/ArgType.java | 31 ++++ .../java/jadx/core/dex/nodes/ClassNode.java | 11 ++ .../jadx/core/dex/nodes/utils/TypeUtils.java | 39 ++++- .../debuginfo/DebugInfoApplyVisitor.java | 2 +- .../typeinference/TypeInferenceVisitor.java | 77 +++++----- .../visitors/typeinference/TypeSearch.java | 6 +- .../visitors/typeinference/TypeUpdate.java | 135 +++++++++++------- .../typeinference/TypeUpdateInfo.java | 9 +- .../invoke/TestCastInOverloadedInvoke.java | 5 +- .../invoke/TestHierarchyOverloadedInvoke.java | 12 +- .../integration/types/TestGenerics6.java | 70 +++++++++ 14 files changed, 325 insertions(+), 108 deletions(-) create mode 100644 jadx-core/src/main/java/jadx/core/dex/attributes/nodes/MethodTypeVarsAttr.java create mode 100644 jadx-core/src/test/java/jadx/tests/integration/types/TestGenerics6.java diff --git a/jadx-core/src/main/java/jadx/core/Consts.java b/jadx-core/src/main/java/jadx/core/Consts.java index f76c23b93..908047fe0 100644 --- a/jadx-core/src/main/java/jadx/core/Consts.java +++ b/jadx-core/src/main/java/jadx/core/Consts.java @@ -3,6 +3,7 @@ package jadx.core; public class Consts { public static final boolean DEBUG = false; public static final boolean DEBUG_USAGE = false; + public static final boolean DEBUG_TYPE_INFERENCE = false; public static final String CLASS_OBJECT = "java.lang.Object"; public static final String CLASS_STRING = "java.lang.String"; diff --git a/jadx-core/src/main/java/jadx/core/dex/attributes/AType.java b/jadx-core/src/main/java/jadx/core/dex/attributes/AType.java index e9be49206..d94edf0c3 100644 --- a/jadx-core/src/main/java/jadx/core/dex/attributes/AType.java +++ b/jadx-core/src/main/java/jadx/core/dex/attributes/AType.java @@ -21,6 +21,7 @@ import jadx.core.dex.attributes.nodes.LoopInfo; import jadx.core.dex.attributes.nodes.LoopLabelAttr; import jadx.core.dex.attributes.nodes.MethodInlineAttr; import jadx.core.dex.attributes.nodes.MethodOverrideAttr; +import jadx.core.dex.attributes.nodes.MethodTypeVarsAttr; import jadx.core.dex.attributes.nodes.PhiListAttr; import jadx.core.dex.attributes.nodes.RegDebugInfoAttr; import jadx.core.dex.attributes.nodes.RenameReasonAttr; @@ -64,6 +65,7 @@ public class AType { public static final AType ANNOTATION_MTH_PARAMETERS = new AType<>(); public static final AType SKIP_MTH_ARGS = new AType<>(); public static final AType METHOD_OVERRIDE = new AType<>(); + public static final AType METHOD_TYPE_VARS = new AType<>(); // region public static final AType DECLARE_VARIABLES = new AType<>(); diff --git a/jadx-core/src/main/java/jadx/core/dex/attributes/nodes/MethodTypeVarsAttr.java b/jadx-core/src/main/java/jadx/core/dex/attributes/nodes/MethodTypeVarsAttr.java new file mode 100644 index 000000000..646f5ddea --- /dev/null +++ b/jadx-core/src/main/java/jadx/core/dex/attributes/nodes/MethodTypeVarsAttr.java @@ -0,0 +1,33 @@ +package jadx.core.dex.attributes.nodes; + +import java.util.Set; + +import jadx.core.dex.attributes.AType; +import jadx.core.dex.attributes.IAttribute; +import jadx.core.dex.instructions.args.ArgType; + +/** + * Set of known type variables at current method + */ +public class MethodTypeVarsAttr implements IAttribute { + + private final Set typeVars; + + public MethodTypeVarsAttr(Set typeVars) { + this.typeVars = typeVars; + } + + public Set getTypeVars() { + return typeVars; + } + + @Override + public AType getType() { + return AType.METHOD_TYPE_VARS; + } + + @Override + public String toString() { + return "TYPE_VARS: " + typeVars; + } +} diff --git a/jadx-core/src/main/java/jadx/core/dex/instructions/args/ArgType.java b/jadx-core/src/main/java/jadx/core/dex/instructions/args/ArgType.java index ca24197e7..688a075e4 100644 --- a/jadx-core/src/main/java/jadx/core/dex/instructions/args/ArgType.java +++ b/jadx-core/src/main/java/jadx/core/dex/instructions/args/ArgType.java @@ -4,6 +4,7 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.Objects; +import java.util.function.Function; import org.jetbrains.annotations.NotNull; @@ -744,6 +745,36 @@ public abstract class ArgType { return false; } + /** + * Recursively visit all subtypes of this type. + * To exit return non-null value. + */ + public R visitTypes(Function visitor) { + R r = visitor.apply(this); + if (r != null) { + return r; + } + ArgType wildcardType = getWildcardType(); + if (wildcardType != null) { + return wildcardType.visitTypes(visitor); + } + if (isArray()) { + ArgType arrayElement = getArrayElement(); + if (arrayElement != null) { + return arrayElement.visitTypes(visitor); + } + } + if (isGeneric()) { + ArgType[] genericTypes = getGenericTypes(); + if (genericTypes != null) { + for (ArgType genericType : genericTypes) { + return genericType.visitTypes(visitor); + } + } + } + return null; + } + public static ArgType tryToResolveClassAlias(RootNode root, ArgType type) { if (!type.isObject() || type.isGenericType()) { return type; diff --git a/jadx-core/src/main/java/jadx/core/dex/nodes/ClassNode.java b/jadx-core/src/main/java/jadx/core/dex/nodes/ClassNode.java index 8525106f1..6c725145a 100644 --- a/jadx-core/src/main/java/jadx/core/dex/nodes/ClassNode.java +++ b/jadx-core/src/main/java/jadx/core/dex/nodes/ClassNode.java @@ -10,6 +10,7 @@ import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Set; +import java.util.function.Consumer; import java.util.stream.Collectors; import org.jetbrains.annotations.NotNull; @@ -425,6 +426,16 @@ public class ClassNode extends NotificationAttrNode implements ILoadable, ICodeN return parent == this ? this : parent.getTopParentClass(); } + public void visitParentClasses(Consumer consumer) { + ClassNode currentCls = this; + ClassNode parentCls = currentCls.getParentClass(); + while (parentCls != currentCls) { + consumer.accept(parentCls); + currentCls = parentCls; + parentCls = currentCls.getParentClass(); + } + } + public boolean hasNotGeneratedParent() { if (contains(AFlag.DONT_GENERATE)) { return true; diff --git a/jadx-core/src/main/java/jadx/core/dex/nodes/utils/TypeUtils.java b/jadx-core/src/main/java/jadx/core/dex/nodes/utils/TypeUtils.java index 2291ee2c0..f16f1e7c0 100644 --- a/jadx-core/src/main/java/jadx/core/dex/nodes/utils/TypeUtils.java +++ b/jadx-core/src/main/java/jadx/core/dex/nodes/utils/TypeUtils.java @@ -2,19 +2,23 @@ package jadx.core.dex.nodes.utils; import java.util.Collections; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Set; -import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.Nullable; import jadx.core.clsp.ClspClass; +import jadx.core.dex.attributes.AType; +import jadx.core.dex.attributes.nodes.MethodTypeVarsAttr; import jadx.core.dex.instructions.BaseInvokeNode; import jadx.core.dex.instructions.args.ArgType; import jadx.core.dex.instructions.args.InsnArg; import jadx.core.dex.nodes.ClassNode; import jadx.core.dex.nodes.GenericTypeParameter; import jadx.core.dex.nodes.IMethodDetails; +import jadx.core.dex.nodes.MethodNode; import jadx.core.dex.nodes.RootNode; public class TypeUtils { @@ -24,7 +28,6 @@ public class TypeUtils { this.root = rootNode; } - @NotNull public List getClassGenerics(ArgType type) { ClassNode classNode = root.resolveClass(type); if (classNode != null) { @@ -38,6 +41,38 @@ public class TypeUtils { return generics == null ? Collections.emptyList() : generics; } + public Set getKnownTypeVarsAtMethod(MethodNode mth) { + MethodTypeVarsAttr typeVarsAttr = mth.get(AType.METHOD_TYPE_VARS); + if (typeVarsAttr != null) { + return typeVarsAttr.getTypeVars(); + } + Set typeVars = collectKnownTypeVarsAtMethod(mth); + mth.addAttr(new MethodTypeVarsAttr(typeVars)); + return typeVars; + } + + private static Set collectKnownTypeVarsAtMethod(MethodNode mth) { + Set typeVars = new HashSet<>(); + ClassNode declCls = mth.getParentClass(); + addTypeVarsFromCls(typeVars, declCls); + declCls.visitParentClasses(parent -> addTypeVarsFromCls(typeVars, parent)); + + for (GenericTypeParameter typeParameter : mth.getTypeParameters()) { + typeVars.add(typeParameter.getTypeVariable()); + } + return typeVars.isEmpty() ? Collections.emptySet() : typeVars; + } + + private static void addTypeVarsFromCls(Set typeVars, ClassNode parentCls) { + List typeParameters = parentCls.getGenericTypeParameters(); + if (typeParameters.isEmpty()) { + return; + } + for (GenericTypeParameter typeParameter : typeParameters) { + typeVars.add(typeParameter.getTypeVariable()); + } + } + /** * Replace generic types in {@code typeWithGeneric} using instance types *
diff --git a/jadx-core/src/main/java/jadx/core/dex/visitors/debuginfo/DebugInfoApplyVisitor.java b/jadx-core/src/main/java/jadx/core/dex/visitors/debuginfo/DebugInfoApplyVisitor.java index f893018f3..bc7734330 100644 --- a/jadx-core/src/main/java/jadx/core/dex/visitors/debuginfo/DebugInfoApplyVisitor.java +++ b/jadx-core/src/main/java/jadx/core/dex/visitors/debuginfo/DebugInfoApplyVisitor.java @@ -143,7 +143,7 @@ public class DebugInfoApplyVisitor extends AbstractVisitor { } public static void applyDebugInfo(MethodNode mth, SSAVar ssaVar, ArgType type, String varName) { - TypeUpdateResult result = mth.root().getTypeUpdate().applyWithWiderAllow(ssaVar, type); + TypeUpdateResult result = mth.root().getTypeUpdate().applyWithWiderAllow(mth, ssaVar, type); if (result == TypeUpdateResult.REJECT) { if (Consts.DEBUG) { LOG.debug("Reject debug info of type: {} and name: '{}' for {}, mth: {}", type, varName, ssaVar, mth); diff --git a/jadx-core/src/main/java/jadx/core/dex/visitors/typeinference/TypeInferenceVisitor.java b/jadx-core/src/main/java/jadx/core/dex/visitors/typeinference/TypeInferenceVisitor.java index 0d51c31da..c408f0634 100644 --- a/jadx-core/src/main/java/jadx/core/dex/visitors/typeinference/TypeInferenceVisitor.java +++ b/jadx-core/src/main/java/jadx/core/dex/visitors/typeinference/TypeInferenceVisitor.java @@ -72,7 +72,7 @@ public final class TypeInferenceVisitor extends AbstractVisitor { if (mth.isNoCode()) { return; } - if (Consts.DEBUG) { + if (Consts.DEBUG_TYPE_INFERENCE) { LOG.info("Start type inference in method: {}", mth); } if (resolveTypes(mth)) { @@ -103,20 +103,21 @@ public final class TypeInferenceVisitor extends AbstractVisitor { /** * Guess type from usage and try to set it to current variable - * and all connected instructions with {@link TypeUpdate#apply(SSAVar, ArgType)} + * and all connected instructions with {@link TypeUpdate#apply(MethodNode, SSAVar, ArgType)} */ private boolean runTypePropagation(MethodNode mth) { + List ssaVars = mth.getSVars(); // collect initial type bounds from assign and usages` - mth.getSVars().forEach(this::attachBounds); - mth.getSVars().forEach(this::mergePhiBounds); + ssaVars.forEach(this::attachBounds); + ssaVars.forEach(this::mergePhiBounds); // start initial type propagation - mth.getSVars().forEach(this::setImmutableType); - mth.getSVars().forEach(this::setBestType); + ssaVars.forEach(var -> setImmutableType(mth, var)); + ssaVars.forEach(var -> setBestType(mth, var)); // try other types if type is still unknown boolean resolved = true; - for (SSAVar var : mth.getSVars()) { + for (SSAVar var : ssaVars) { ArgType type = var.getTypeInfo().getType(); if (!type.isTypeKnown() && !var.isTypeImmutable() @@ -131,7 +132,7 @@ public final class TypeInferenceVisitor extends AbstractVisitor { TypeSearch typeSearch = new TypeSearch(mth); try { if (!typeSearch.run()) { - mth.addWarn("Multi-variable type inference failed"); + mth.addWarnComment("Multi-variable type inference failed"); } for (SSAVar var : mth.getSVars()) { if (!var.getTypeInfo().getType().isTypeKnown()) { @@ -140,50 +141,44 @@ public final class TypeInferenceVisitor extends AbstractVisitor { } return true; } catch (Exception e) { - mth.addWarn("Multi-variable type inference failed. Error: " + Utils.getStackTrace(e)); + mth.addWarnComment("Multi-variable type inference failed. Error: " + Utils.getStackTrace(e)); return false; } } - private boolean setImmutableType(SSAVar ssaVar) { + private void setImmutableType(MethodNode mth, SSAVar ssaVar) { try { ArgType immutableType = ssaVar.getImmutableType(); if (immutableType != null) { - return applyImmutableType(ssaVar, immutableType); + applyImmutableType(mth, ssaVar, immutableType); } - return false; } catch (Exception e) { LOG.error("Failed to set immutable type for var: {}", ssaVar, e); - return false; } } - private boolean setBestType(SSAVar ssaVar) { + private boolean setBestType(MethodNode mth, SSAVar ssaVar) { try { - return calculateFromBounds(ssaVar); + return calculateFromBounds(mth, ssaVar); } catch (Exception e) { LOG.error("Failed to calculate best type for var: {}", ssaVar, e); return false; } } - private boolean applyImmutableType(SSAVar ssaVar, ArgType initType) { - TypeUpdateResult result = typeUpdate.apply(ssaVar, initType); - if (result == TypeUpdateResult.REJECT) { - if (Consts.DEBUG) { - LOG.info("Reject initial immutable type {} for {}", initType, ssaVar); - } - return false; + private void applyImmutableType(MethodNode mth, SSAVar ssaVar, ArgType initType) { + TypeUpdateResult result = typeUpdate.apply(mth, ssaVar, initType); + if (Consts.DEBUG_TYPE_INFERENCE && result == TypeUpdateResult.REJECT) { + LOG.info("Reject initial immutable type {} for {}", initType, ssaVar); } - return result == TypeUpdateResult.CHANGED; } - private boolean calculateFromBounds(SSAVar ssaVar) { + private boolean calculateFromBounds(MethodNode mth, SSAVar ssaVar) { TypeInfo typeInfo = ssaVar.getTypeInfo(); Set bounds = typeInfo.getBounds(); Optional bestTypeOpt = selectBestTypeFromBounds(bounds); if (!bestTypeOpt.isPresent()) { - if (Consts.DEBUG) { + if (Consts.DEBUG_TYPE_INFERENCE) { LOG.warn("Failed to select best type from bounds, count={} : ", bounds.size()); for (ITypeBound bound : bounds) { LOG.warn(" {}", bound); @@ -192,9 +187,9 @@ public final class TypeInferenceVisitor extends AbstractVisitor { return false; } ArgType candidateType = bestTypeOpt.get(); - TypeUpdateResult result = typeUpdate.apply(ssaVar, candidateType); + TypeUpdateResult result = typeUpdate.apply(mth, ssaVar, candidateType); if (result == TypeUpdateResult.REJECT) { - if (Consts.DEBUG) { + if (Consts.DEBUG_TYPE_INFERENCE) { if (ssaVar.getTypeInfo().getType().equals(candidateType)) { LOG.info("Same type rejected: {} -> {}, bounds: {}", ssaVar, candidateType, bounds); } else if (candidateType.isTypeKnown()) { @@ -235,7 +230,11 @@ public final class TypeInferenceVisitor extends AbstractVisitor { } private void addBound(TypeInfo typeInfo, ITypeBound bound) { - if (bound != null && bound.getType() != ArgType.UNKNOWN) { + if (bound == null) { + return; + } + if (bound instanceof ITypeBoundDynamic + || bound.getType() != ArgType.UNKNOWN) { typeInfo.getBounds().add(bound); } } @@ -333,10 +332,10 @@ public final class TypeInferenceVisitor extends AbstractVisitor { return new TypeBoundInvokeUse(root, invoke, regArg, argType); } - private boolean tryPossibleTypes(SSAVar var, ArgType type) { + private boolean tryPossibleTypes(MethodNode mth, SSAVar var, ArgType type) { List types = makePossibleTypesList(type); for (ArgType candidateType : types) { - TypeUpdateResult result = typeUpdate.apply(var, candidateType); + TypeUpdateResult result = typeUpdate.apply(mth, var, candidateType); if (result == TypeUpdateResult.CHANGED) { return true; } @@ -362,11 +361,11 @@ public final class TypeInferenceVisitor extends AbstractVisitor { private boolean tryDeduceType(MethodNode mth, SSAVar var, @Nullable ArgType type) { // try best type from bounds again - if (setBestType(var)) { + if (setBestType(mth, var)) { return true; } // try all possible types (useful for primitives) - if (type != null && tryPossibleTypes(var, type)) { + if (type != null && tryPossibleTypes(mth, var, type)) { return true; } // for objects try super types @@ -412,7 +411,7 @@ public final class TypeInferenceVisitor extends AbstractVisitor { private boolean checkRawType(MethodNode mth, SSAVar var, ArgType objType) { if (objType.isObject() && objType.containsGeneric()) { ArgType rawType = ArgType.object(objType.getObject()); - TypeUpdateResult result = typeUpdate.applyWithWiderAllow(var, rawType); + TypeUpdateResult result = typeUpdate.applyWithWiderAllow(mth, var, rawType); return result == TypeUpdateResult.CHANGED; } return false; @@ -575,7 +574,7 @@ public final class TypeInferenceVisitor extends AbstractVisitor { for (ArgType objType : objTypes) { for (String ancestor : clsp.getSuperTypes(objType.getObject())) { ArgType ancestorType = ArgType.object(ancestor); - TypeUpdateResult result = typeUpdate.applyWithWiderAllow(var, ancestorType); + TypeUpdateResult result = typeUpdate.applyWithWiderAllow(mth, var, ancestorType); if (result == TypeUpdateResult.CHANGED) { return true; } @@ -588,7 +587,9 @@ public final class TypeInferenceVisitor extends AbstractVisitor { if (var.getTypeInfo().getType() == ArgType.BOOLEAN) { for (ITypeBound bound : var.getTypeInfo().getBounds()) { if (bound.getBound() == BoundEnum.USE - && bound.getType().isPrimitive() && bound.getType() != ArgType.BOOLEAN) { + && bound.getType().isPrimitive() + && bound.getType() != ArgType.BOOLEAN + && bound.getArg() != null) { InsnNode insn = bound.getArg().getParentInsn(); if (insn == null || insn.getType() == InsnType.CAST) { continue; @@ -612,8 +613,10 @@ public final class TypeInferenceVisitor extends AbstractVisitor { } BlockNode blockNode = BlockUtils.getBlockByInsn(mth, insn); - List insnList = blockNode.getInstructions(); - insnList.add(insnList.indexOf(insn), castNode); + if (blockNode != null) { + List insnList = blockNode.getInstructions(); + insnList.add(insnList.indexOf(insn), castNode); + } } } } diff --git a/jadx-core/src/main/java/jadx/core/dex/visitors/typeinference/TypeSearch.java b/jadx-core/src/main/java/jadx/core/dex/visitors/typeinference/TypeSearch.java index 3a4e139a6..dd1b3df4c 100644 --- a/jadx-core/src/main/java/jadx/core/dex/visitors/typeinference/TypeSearch.java +++ b/jadx-core/src/main/java/jadx/core/dex/visitors/typeinference/TypeSearch.java @@ -63,7 +63,7 @@ public class TypeSearch { } else { search(vars); searchSuccess = fullCheck(vars); - if (Consts.DEBUG && !searchSuccess) { + if (Consts.DEBUG_TYPE_INFERENCE && !searchSuccess) { LOG.warn("Multi-variable search failed in {}", mth); } } @@ -86,7 +86,7 @@ public class TypeSearch { // exclude unknown variables continue; } - TypeUpdateResult res = typeUpdate.applyWithWiderIgnSame(var.getVar(), var.getCurrentType()); + TypeUpdateResult res = typeUpdate.applyWithWiderIgnSame(mth, var.getVar(), var.getCurrentType()); if (res == TypeUpdateResult.REJECT) { mth.addComment("JADX DEBUG: Multi-variable search result rejected for " + var); applySuccess = false; @@ -97,7 +97,7 @@ public class TypeSearch { private boolean search(List vars) { int len = vars.size(); - if (Consts.DEBUG) { + if (Consts.DEBUG_TYPE_INFERENCE) { LOG.debug("Run search for {} vars: ", len); StringBuilder sb = new StringBuilder(); long count = 1; diff --git a/jadx-core/src/main/java/jadx/core/dex/visitors/typeinference/TypeUpdate.java b/jadx-core/src/main/java/jadx/core/dex/visitors/typeinference/TypeUpdate.java index fa18c5a3e..a690b69ba 100644 --- a/jadx-core/src/main/java/jadx/core/dex/visitors/typeinference/TypeUpdate.java +++ b/jadx-core/src/main/java/jadx/core/dex/visitors/typeinference/TypeUpdate.java @@ -5,6 +5,8 @@ import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Set; +import java.util.function.Function; +import java.util.function.Supplier; import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.Nullable; @@ -22,6 +24,7 @@ import jadx.core.dex.instructions.args.RegisterArg; import jadx.core.dex.instructions.args.SSAVar; import jadx.core.dex.nodes.IMethodDetails; import jadx.core.dex.nodes.InsnNode; +import jadx.core.dex.nodes.MethodNode; import jadx.core.dex.nodes.RootNode; import jadx.core.dex.nodes.utils.TypeUtils; import jadx.core.utils.exceptions.JadxOverflowException; @@ -47,30 +50,30 @@ public final class TypeUpdate { /** * Perform recursive type checking and type propagation for all related variables */ - public TypeUpdateResult apply(SSAVar ssaVar, ArgType candidateType) { - return apply(ssaVar, candidateType, TypeUpdateFlags.FLAGS_EMPTY); + public TypeUpdateResult apply(MethodNode mth, SSAVar ssaVar, ArgType candidateType) { + return apply(mth, ssaVar, candidateType, TypeUpdateFlags.FLAGS_EMPTY); } /** * Allow wider types for apply from debug info and some special cases */ - public TypeUpdateResult applyWithWiderAllow(SSAVar ssaVar, ArgType candidateType) { - return apply(ssaVar, candidateType, TypeUpdateFlags.FLAGS_WIDER); + public TypeUpdateResult applyWithWiderAllow(MethodNode mth, SSAVar ssaVar, ArgType candidateType) { + return apply(mth, ssaVar, candidateType, TypeUpdateFlags.FLAGS_WIDER); } /** * Force type setting */ - public TypeUpdateResult applyWithWiderIgnSame(SSAVar ssaVar, ArgType candidateType) { - return apply(ssaVar, candidateType, TypeUpdateFlags.FLAGS_WIDER_IGNSAME); + public TypeUpdateResult applyWithWiderIgnSame(MethodNode mth, SSAVar ssaVar, ArgType candidateType) { + return apply(mth, ssaVar, candidateType, TypeUpdateFlags.FLAGS_WIDER_IGNSAME); } - private TypeUpdateResult apply(SSAVar ssaVar, ArgType candidateType, TypeUpdateFlags flags) { + private TypeUpdateResult apply(MethodNode mth, SSAVar ssaVar, ArgType candidateType, TypeUpdateFlags flags) { if (candidateType == null || !candidateType.isTypeKnown()) { return REJECT; } - TypeUpdateInfo updateInfo = new TypeUpdateInfo(flags); + TypeUpdateInfo updateInfo = new TypeUpdateInfo(mth, flags); TypeUpdateResult result = updateTypeChecked(updateInfo, ssaVar.getAssign(), candidateType); if (result == REJECT) { return result; @@ -79,7 +82,7 @@ public final class TypeUpdate { if (updates.isEmpty()) { return SAME; } - if (Consts.DEBUG) { + if (Consts.DEBUG_TYPE_INFERENCE) { LOG.debug("Applying types for {} -> {}", ssaVar, candidateType); updates.forEach(updateEntry -> LOG.debug(" {} -> {}, insn: {}", updateEntry.getType(), updateEntry.getArg(), updateEntry.getArg().getParentInsn())); @@ -102,13 +105,13 @@ public final class TypeUpdate { if (compareResult == TypeCompareEnum.EQUAL) { return SAME; } - if (Consts.DEBUG) { + if (Consts.DEBUG_TYPE_INFERENCE) { LOG.debug("Type rejected for {} due to conflict: candidate={}, current={}", arg, candidateType, currentType); } return REJECT; } if (compareResult.isWider() && !updateInfo.getFlags().isAllowWider()) { - if (Consts.DEBUG) { + if (Consts.DEBUG_TYPE_INFERENCE) { LOG.debug("Type rejected for {}: candidate={} is wider than current={}", arg, candidateType, currentType); } return REJECT; @@ -124,13 +127,13 @@ public final class TypeUpdate { TypeInfo typeInfo = ssaVar.getTypeInfo(); ArgType immutableType = ssaVar.getImmutableType(); if (immutableType != null && !Objects.equals(immutableType, candidateType)) { - if (Consts.DEBUG) { + if (Consts.DEBUG_TYPE_INFERENCE) { LOG.info("Reject change immutable type {} to {} for {}", immutableType, candidateType, ssaVar); } return REJECT; } if (!inBounds(updateInfo, typeInfo.getBounds(), candidateType)) { - if (Consts.DEBUG) { + if (Consts.DEBUG_TYPE_INFERENCE) { LOG.debug("Reject type '{}' for {} by bounds: {}", candidateType, ssaVar, typeInfo.getBounds()); } return REJECT; @@ -164,7 +167,7 @@ public final class TypeUpdate { } updateInfo.requestUpdate(arg, candidateType); if (updateInfo.getUpdates().size() > 500) { - if (Consts.DEBUG) { + if (Consts.DEBUG_TYPE_INFERENCE) { LOG.error("Type update error: too deep update tree"); } return REJECT; @@ -287,55 +290,89 @@ public final class TypeUpdate { // TODO: implement backward type propagation (from result to instance) return SAME; } - if (invoke.getInstanceArg() == arg && candidateType.containsGeneric()) { - // resolve result and arg types from generic instance type + if (invoke.getInstanceArg() == arg) { IMethodDetails methodDetails = root.getMethodUtils().getMethodDetails(invoke); if (methodDetails == null) { return SAME; } TypeUtils typeUtils = root.getTypeUtils(); + Set knownTypeVars = typeUtils.getKnownTypeVarsAtMethod(updateInfo.getMth()); Map typeVarsMap = typeUtils.getTypeVariablesMapping(candidateType); - if (typeVarsMap.isEmpty()) { - return SAME; - } - boolean allSame = true; - if (invoke.getResult() != null) { - ArgType returnType = typeUtils.replaceTypeVariablesUsingMap(methodDetails.getReturnType(), typeVarsMap); - if (returnType != null) { - TypeUpdateResult result = updateTypeChecked(updateInfo, invoke.getResult(), returnType); - if (result == REJECT) { - return REJECT; - } - if (result == CHANGED) { - allSame = false; - } - } - } - - int argOffset = invoke.getFirstArgOffset(); + ArgType returnType = methodDetails.getReturnType(); List argTypes = methodDetails.getArgTypes(); int argsCount = argTypes.size(); - for (int i = 0; i < argsCount; i++) { - ArgType genericArgType = argTypes.get(i); - ArgType resultArgType = typeUtils.replaceClassGenerics(candidateType, genericArgType); - if (resultArgType != null) { - InsnArg invokeArg = invoke.getArg(argOffset + i); - TypeUpdateResult result = updateTypeChecked(updateInfo, invokeArg, resultArgType); - if (result == REJECT) { - return REJECT; - } - if (result == CHANGED) { - allSame = false; - } - } + if (typeVarsMap.isEmpty()) { + // generics can't be resolved => use as is + return applyInvokeTypes(updateInfo, invoke, argsCount, knownTypeVars, () -> returnType, argTypes::get); } - return allSame ? SAME : CHANGED; + // resolve types before apply + return applyInvokeTypes(updateInfo, invoke, argsCount, knownTypeVars, + () -> typeUtils.replaceTypeVariablesUsingMap(returnType, typeVarsMap), + argNum -> typeUtils.replaceClassGenerics(candidateType, argTypes.get(argNum))); } return SAME; } + private TypeUpdateResult applyInvokeTypes(TypeUpdateInfo updateInfo, BaseInvokeNode invoke, int argsCount, + Set knownTypeVars, Supplier getReturnType, Function getArgType) { + boolean allSame = true; + RegisterArg resultArg = invoke.getResult(); + if (resultArg != null && !resultArg.isTypeImmutable()) { + ArgType returnType = checkType(knownTypeVars, getReturnType.get()); + if (returnType != null) { + TypeUpdateResult result = updateTypeChecked(updateInfo, resultArg, returnType); + if (result == REJECT) { + TypeCompareEnum compare = comparator.compareTypes(returnType, resultArg.getType()); + if (compare.isWider()) { + return REJECT; + } + } + if (result == CHANGED) { + allSame = false; + } + } + } + int argOffset = invoke.getFirstArgOffset(); + for (int i = 0; i < argsCount; i++) { + InsnArg invokeArg = invoke.getArg(argOffset + i); + if (!invokeArg.isTypeImmutable()) { + ArgType argType = checkType(knownTypeVars, getArgType.apply(i)); + if (argType != null) { + TypeUpdateResult result = updateTypeChecked(updateInfo, invokeArg, argType); + if (result == REJECT) { + TypeCompareEnum compare = comparator.compareTypes(argType, invokeArg.getType()); + if (compare.isNarrow()) { + return REJECT; + } + } + if (result == CHANGED) { + allSame = false; + } + } + } + } + return allSame ? SAME : CHANGED; + } + + @Nullable + private ArgType checkType(Set knownTypeVars, @Nullable ArgType type) { + if (type == null) { + return null; + } + if (type.containsTypeVariable()) { + if (knownTypeVars.isEmpty()) { + return null; + } + Boolean hasUnknown = type.visitTypes(t -> t.isGenericType() && !knownTypeVars.contains(t) ? Boolean.TRUE : null); + if (hasUnknown != null) { + return null; + } + } + return type; + } + private TypeUpdateResult sameFirstArgListener(TypeUpdateInfo updateInfo, InsnNode insn, InsnArg arg, ArgType candidateType) { InsnArg changeArg = isAssign(insn, arg) ? insn.getArg(0) : insn.getResult(); return updateTypeChecked(updateInfo, changeArg, candidateType); @@ -356,7 +393,7 @@ public final class TypeUpdate { TypeUpdateResult result = updateTypeChecked(updateInfo, changeArg, candidateType); if (result == SAME && !correctType) { - if (Consts.DEBUG) { + if (Consts.DEBUG_TYPE_INFERENCE) { LOG.debug("Move insn types mismatch: {} -> {}, change arg: {}, insn: {}", candidateType, changeArg.getType(), changeArg, insn); } diff --git a/jadx-core/src/main/java/jadx/core/dex/visitors/typeinference/TypeUpdateInfo.java b/jadx-core/src/main/java/jadx/core/dex/visitors/typeinference/TypeUpdateInfo.java index 1b2a84ced..b2e8911cc 100644 --- a/jadx-core/src/main/java/jadx/core/dex/visitors/typeinference/TypeUpdateInfo.java +++ b/jadx-core/src/main/java/jadx/core/dex/visitors/typeinference/TypeUpdateInfo.java @@ -5,12 +5,15 @@ import java.util.List; import jadx.core.dex.instructions.args.ArgType; import jadx.core.dex.instructions.args.InsnArg; +import jadx.core.dex.nodes.MethodNode; public class TypeUpdateInfo { + private final MethodNode mth; private final TypeUpdateFlags flags; private final List updates = new ArrayList<>(); - public TypeUpdateInfo(TypeUpdateFlags flags) { + public TypeUpdateInfo(MethodNode mth, TypeUpdateFlags flags) { + this.mth = mth; this.flags = flags; } @@ -50,6 +53,10 @@ public class TypeUpdateInfo { updates.removeIf(updateEntry -> updateEntry.getArg() == arg); } + public MethodNode getMth() { + return mth; + } + public List getUpdates() { return updates; } diff --git a/jadx-core/src/test/java/jadx/tests/integration/invoke/TestCastInOverloadedInvoke.java b/jadx-core/src/test/java/jadx/tests/integration/invoke/TestCastInOverloadedInvoke.java index 1a081f927..a175d7baf 100644 --- a/jadx-core/src/test/java/jadx/tests/integration/invoke/TestCastInOverloadedInvoke.java +++ b/jadx-core/src/test/java/jadx/tests/integration/invoke/TestCastInOverloadedInvoke.java @@ -64,7 +64,7 @@ public class TestCastInOverloadedInvoke extends IntegrationTest { ClassNode cls = getClassNode(TestCls.class); String code = cls.getCode().toString(); - assertThat(code, containsOne("call((ArrayList) new ArrayList());")); + assertThat(code, containsOne("call(new ArrayList<>());")); assertThat(code, containsOne("call((List) new ArrayList());")); assertThat(code, containsOne("call((String) obj);")); @@ -76,9 +76,6 @@ public class TestCastInOverloadedInvoke extends IntegrationTest { ClassNode cls = getClassNode(TestCls.class); String code = cls.getCode().toString(); - assertThat(code, containsOne("call(new ArrayList<>());")); assertThat(code, containsOne("call((List) new ArrayList());")); - - assertThat(code, containsOne("call((String) obj);")); } } diff --git a/jadx-core/src/test/java/jadx/tests/integration/invoke/TestHierarchyOverloadedInvoke.java b/jadx-core/src/test/java/jadx/tests/integration/invoke/TestHierarchyOverloadedInvoke.java index 7569de41a..30b1ffb2b 100644 --- a/jadx-core/src/test/java/jadx/tests/integration/invoke/TestHierarchyOverloadedInvoke.java +++ b/jadx-core/src/test/java/jadx/tests/integration/invoke/TestHierarchyOverloadedInvoke.java @@ -5,7 +5,6 @@ import java.util.List; import org.junit.jupiter.api.Test; -import jadx.NotYetImplemented; import jadx.core.dex.nodes.ClassNode; import jadx.tests.api.IntegrationTest; @@ -83,18 +82,9 @@ public class TestHierarchyOverloadedInvoke extends IntegrationTest { ClassNode cls = getClassNode(TestCls.class); String code = cls.getCode().toString(); - assertThat(code, containsOne("b.call((ArrayList) new ArrayList());")); + assertThat(code, containsOne("b.call(new ArrayList<>());")); assertThat(code, containsOne("b.call((List) new ArrayList());")); assertThat(code, containsOne("b.call((String) obj);")); } - - @NotYetImplemented - @Test - public void test2() { - ClassNode cls = getClassNode(TestCls.class); - String code = cls.getCode().toString(); - - assertThat(code, containsOne("b.call(new ArrayList<>());")); - } } diff --git a/jadx-core/src/test/java/jadx/tests/integration/types/TestGenerics6.java b/jadx-core/src/test/java/jadx/tests/integration/types/TestGenerics6.java new file mode 100644 index 000000000..d4913ef54 --- /dev/null +++ b/jadx-core/src/test/java/jadx/tests/integration/types/TestGenerics6.java @@ -0,0 +1,70 @@ +package jadx.tests.integration.types; + +import java.util.Iterator; +import java.util.Map; + +import org.junit.jupiter.api.Test; + +import jadx.tests.api.IntegrationTest; + +import static jadx.tests.api.utils.assertj.JadxAssertions.assertThat; + +public class TestGenerics6 extends IntegrationTest { + + public static class TestCls implements Iterable> { + public V test(K key, V v) { + Entry entry = get(key); + if (entry != null) { + return entry.mValue; + } + put(key, v); + return null; + } + + protected Entry get(K k) { + return null; + } + + protected Entry put(K key, V v) { + return null; + } + + @Override + public Iterator> iterator() { + return null; + } + + static class Entry implements Map.Entry { + final V mValue; + + Entry(K key, V value) { + this.mValue = value; + } + + @Override + public K getKey() { + return null; + } + + @Override + public V getValue() { + return null; + } + + @Override + public V setValue(V value) { + return null; + } + } + + } + + @Test + public void test() { + noDebugInfo(); + assertThat(getClassNode(TestCls.class)) + .code() + .doesNotContain("Entry entry = get(k);") + .containsOne("Entry entry = get(k);"); + } +}