From 404136cd72cca845e7cccbc57ffdd6aec31b572e Mon Sep 17 00:00:00 2001 From: Skylot Date: Sun, 10 May 2020 14:20:27 +0100 Subject: [PATCH] fix: improve type inference for generics in invoke insn (#927) --- .../typeinference/TypeBoundInvokeUse.java | 78 +++++++++++++++++ .../typeinference/TypeInferenceVisitor.java | 26 ++++-- .../visitors/typeinference/TypeUpdate.java | 83 ++++++++++++++----- .../integration/types/TestGenerics5.java | 44 ++++++++++ 4 files changed, 205 insertions(+), 26 deletions(-) create mode 100644 jadx-core/src/main/java/jadx/core/dex/visitors/typeinference/TypeBoundInvokeUse.java create mode 100644 jadx-core/src/test/java/jadx/tests/integration/types/TestGenerics5.java diff --git a/jadx-core/src/main/java/jadx/core/dex/visitors/typeinference/TypeBoundInvokeUse.java b/jadx-core/src/main/java/jadx/core/dex/visitors/typeinference/TypeBoundInvokeUse.java new file mode 100644 index 000000000..7dfa14ab9 --- /dev/null +++ b/jadx-core/src/main/java/jadx/core/dex/visitors/typeinference/TypeBoundInvokeUse.java @@ -0,0 +1,78 @@ +package jadx.core.dex.visitors.typeinference; + +import jadx.core.dex.instructions.BaseInvokeNode; +import jadx.core.dex.instructions.args.ArgType; +import jadx.core.dex.instructions.args.RegisterArg; +import jadx.core.dex.nodes.RootNode; + +/** + * Special dynamic bound for invoke with generics. + * Arguments bound type calculated using instance generic type. + */ +public final class TypeBoundInvokeUse implements ITypeBoundDynamic { + private final RootNode root; + private final BaseInvokeNode invokeNode; + private final RegisterArg arg; + private final ArgType genericArgType; + + public TypeBoundInvokeUse(RootNode root, BaseInvokeNode invokeNode, RegisterArg arg, ArgType genericArgType) { + this.root = root; + this.invokeNode = invokeNode; + this.arg = arg; + this.genericArgType = genericArgType; + } + + @Override + public BoundEnum getBound() { + return BoundEnum.USE; + } + + @Override + public ArgType getType(TypeUpdateInfo updateInfo) { + return getArgType(updateInfo.getType(invokeNode.getInstanceArg()), updateInfo.getType(arg)); + } + + @Override + public ArgType getType() { + return getArgType(invokeNode.getInstanceArg().getType(), arg.getType()); + } + + private ArgType getArgType(ArgType instanceType, ArgType argType) { + ArgType resultGeneric = root.getTypeUtils().replaceClassGenerics(instanceType, genericArgType); + if (resultGeneric != null) { + return resultGeneric; + } + return argType; + } + + @Override + public RegisterArg getArg() { + return arg; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + TypeBoundInvokeUse that = (TypeBoundInvokeUse) o; + return invokeNode.equals(that.invokeNode); + } + + @Override + public int hashCode() { + return invokeNode.hashCode(); + } + + @Override + public String toString() { + return "InvokeAssign{" + invokeNode.getCallMth().getShortId() + + ", argType=" + genericArgType + + ", currentType=" + getType() + + ", instanceArg=" + invokeNode.getInstanceArg() + + '}'; + } +} diff --git a/jadx-core/src/main/java/jadx/core/dex/visitors/typeinference/TypeInferenceVisitor.java b/jadx-core/src/main/java/jadx/core/dex/visitors/typeinference/TypeInferenceVisitor.java index 87ec0b3ec..ac28c858e 100644 --- a/jadx-core/src/main/java/jadx/core/dex/visitors/typeinference/TypeInferenceVisitor.java +++ b/jadx-core/src/main/java/jadx/core/dex/visitors/typeinference/TypeInferenceVisitor.java @@ -298,17 +298,31 @@ public final class TypeInferenceVisitor extends AbstractVisitor { return null; } if (insn instanceof BaseInvokeNode) { - IMethodDetails methodDetails = root.getMethodUtils().getMethodDetails((BaseInvokeNode) insn); - if (methodDetails != null) { - if (methodDetails.getArgTypes().stream().anyMatch(ArgType::containsTypeVariable)) { - // don't add const bound for generic type variables - return null; - } + TypeBoundInvokeUse invokeUseBound = makeInvokeUseBound(regArg, (BaseInvokeNode) insn); + if (invokeUseBound != null) { + return invokeUseBound; } } return new TypeBoundConst(BoundEnum.USE, regArg.getInitType(), regArg); } + private TypeBoundInvokeUse makeInvokeUseBound(RegisterArg regArg, BaseInvokeNode invoke) { + InsnArg instanceArg = invoke.getInstanceArg(); + if (instanceArg == null || instanceArg == regArg) { + return null; + } + IMethodDetails methodDetails = root.getMethodUtils().getMethodDetails(invoke); + if (methodDetails == null) { + return null; + } + int argIndex = invoke.getArgIndex(regArg) - invoke.getFirstArgOffset(); + ArgType argType = methodDetails.getArgTypes().get(argIndex); + if (!argType.containsTypeVariable()) { + return null; + } + return new TypeBoundInvokeUse(root, invoke, regArg, argType); + } + private boolean tryPossibleTypes(SSAVar var, ArgType type) { List types = makePossibleTypesList(type); for (ArgType candidateType : types) { diff --git a/jadx-core/src/main/java/jadx/core/dex/visitors/typeinference/TypeUpdate.java b/jadx-core/src/main/java/jadx/core/dex/visitors/typeinference/TypeUpdate.java index 35f6ccb51..47bf57141 100644 --- a/jadx-core/src/main/java/jadx/core/dex/visitors/typeinference/TypeUpdate.java +++ b/jadx-core/src/main/java/jadx/core/dex/visitors/typeinference/TypeUpdate.java @@ -12,6 +12,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import jadx.core.Consts; +import jadx.core.dex.instructions.IndexInsnNode; import jadx.core.dex.instructions.InsnType; import jadx.core.dex.instructions.InvokeNode; import jadx.core.dex.instructions.args.ArgType; @@ -19,8 +20,10 @@ import jadx.core.dex.instructions.args.InsnArg; import jadx.core.dex.instructions.args.PrimitiveType; import jadx.core.dex.instructions.args.RegisterArg; import jadx.core.dex.instructions.args.SSAVar; +import jadx.core.dex.nodes.IMethodDetails; import jadx.core.dex.nodes.InsnNode; import jadx.core.dex.nodes.RootNode; +import jadx.core.dex.nodes.utils.TypeUtils; import jadx.core.utils.exceptions.JadxOverflowException; import jadx.core.utils.exceptions.JadxRuntimeException; @@ -278,27 +281,58 @@ public final class TypeUpdate { } private TypeUpdateResult invokeListener(TypeUpdateInfo updateInfo, InsnNode insn, InsnArg arg, ArgType candidateType) { - if (insn.getResult() == null) { + InvokeNode invoke = (InvokeNode) insn; + if (isAssign(invoke, arg)) { + // TODO: implement backward type propagation (from result to instance) return SAME; } - if (candidateType.containsTypeVariable()) { - InvokeNode invokeNode = (InvokeNode) insn; - if (isAssign(insn, arg)) { - // TODO: implement backward type propagation (from result to instance) + if (invoke.getInstanceArg() == arg && candidateType.containsGeneric()) { + // resolve result and arg types from generic instance type + IMethodDetails methodDetails = root.getMethodUtils().getMethodDetails(invoke); + if (methodDetails == null) { return SAME; - } else { - ArgType returnType = root.getMethodUtils().getMethodGenericReturnType(invokeNode); - if (returnType == null) { - return SAME; - } - ArgType resultGeneric = root.getTypeUtils().replaceClassGenerics(candidateType, returnType); - if (resultGeneric == null) { - return SAME; - } - return updateTypeChecked(updateInfo, insn.getResult(), resultGeneric); } + TypeUtils typeUtils = root.getTypeUtils(); + Map typeVarsMap = typeUtils.getTypeVariablesMapping(candidateType); + if (typeVarsMap.isEmpty()) { + return SAME; + } + + boolean allSame = true; + if (invoke.getResult() != null) { + ArgType returnType = typeUtils.replaceTypeVariablesUsingMap(methodDetails.getReturnType(), typeVarsMap); + if (returnType != null) { + TypeUpdateResult result = updateTypeChecked(updateInfo, invoke.getResult(), returnType); + if (result == REJECT) { + return REJECT; + } + if (result == CHANGED) { + allSame = false; + } + } + } + + int argOffset = invoke.getFirstArgOffset(); + List argTypes = methodDetails.getArgTypes(); + int argsCount = argTypes.size(); + for (int i = 0; i < argsCount; i++) { + ArgType genericArgType = argTypes.get(i); + ArgType resultArgType = typeUtils.replaceClassGenerics(candidateType, genericArgType); + if (resultArgType != null) { + InsnArg invokeArg = invoke.getArg(argOffset + i); + TypeUpdateResult result = updateTypeChecked(updateInfo, invokeArg, resultArgType); + if (result == REJECT) { + return REJECT; + } + if (result == CHANGED) { + allSame = false; + } + } + } + return allSame ? SAME : CHANGED; } return SAME; + } private TypeUpdateResult sameFirstArgListener(TypeUpdateInfo updateInfo, InsnNode insn, InsnArg arg, ArgType candidateType) { @@ -377,12 +411,21 @@ public final class TypeUpdate { } private TypeUpdateResult checkCastListener(TypeUpdateInfo updateInfo, InsnNode insn, InsnArg arg, ArgType candidateType) { - if (!isAssign(insn, arg)) { - return SAME; + IndexInsnNode checkCast = (IndexInsnNode) insn; + if (isAssign(insn, arg)) { + InsnArg insnArg = insn.getArg(0); + TypeUpdateResult result = updateTypeChecked(updateInfo, insnArg, candidateType); + return result == REJECT ? SAME : result; } - InsnArg insnArg = insn.getArg(0); - TypeUpdateResult result = updateTypeChecked(updateInfo, insnArg, candidateType); - return result == REJECT ? SAME : result; + if (candidateType.containsGeneric()) { + ArgType castType = (ArgType) checkCast.getIndex(); + TypeCompareEnum compResult = comparator.compareTypes(candidateType, castType); + if (compResult == TypeCompareEnum.NARROW_BY_GENERIC) { + // propagate generic type to result + return updateTypeChecked(updateInfo, checkCast.getResult(), candidateType); + } + } + return SAME; } private TypeUpdateResult arrayGetListener(TypeUpdateInfo updateInfo, InsnNode insn, InsnArg arg, ArgType candidateType) { diff --git a/jadx-core/src/test/java/jadx/tests/integration/types/TestGenerics5.java b/jadx-core/src/test/java/jadx/tests/integration/types/TestGenerics5.java new file mode 100644 index 000000000..2621cc19f --- /dev/null +++ b/jadx-core/src/test/java/jadx/tests/integration/types/TestGenerics5.java @@ -0,0 +1,44 @@ +package jadx.tests.integration.types; + +import java.util.HashMap; +import java.util.Map; + +import org.junit.jupiter.api.Test; + +import jadx.tests.api.IntegrationTest; + +import static jadx.tests.api.utils.assertj.JadxAssertions.assertThat; + +public class TestGenerics5 extends IntegrationTest { + + public static class TestCls { + private InheritableThreadLocal> inheritableThreadLocal; + + public void put(String key, String val) { + if (key == null) { + throw new IllegalArgumentException("key cannot be null"); + } + Map map = this.inheritableThreadLocal.get(); + if (map == null) { + map = new HashMap<>(); + this.inheritableThreadLocal.set(map); + } + map.put(key, val); + } + + public void remove(String key) { + Map map = this.inheritableThreadLocal.get(); + if (map != null) { + map.remove(key); + } + } + } + + @Test + public void test() { + noDebugInfo(); + assertThat(getClassNode(TestCls.class)) + .code() + .countString(2, "Map map = this.inheritableThreadLocal.get();"); + } +}