fix: mark override methods and fix return type (#913)

This commit is contained in:
Skylot
2020-05-11 20:42:06 +01:00
parent 3968222744
commit 0692464b85
18 changed files with 328 additions and 37 deletions
@@ -57,6 +57,8 @@ public class Jadx {
passes.add(new BlockExceptionHandler());
passes.add(new BlockFinish());
passes.add(new OverrideMethodVisitor());
passes.add(new SSATransform());
passes.add(new MoveInlineVisitor());
passes.add(new ConstructorVisitor());
@@ -10,6 +10,7 @@ import java.util.Map;
import java.util.Set;
import java.util.WeakHashMap;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -29,7 +30,7 @@ public class ClspGraph {
private static final Logger LOG = LoggerFactory.getLogger(ClspGraph.class);
private final RootNode root;
private final Map<String, Set<String>> ancestorCache = Collections.synchronizedMap(new WeakHashMap<>());
private final Map<String, Set<String>> superTypesCache = Collections.synchronizedMap(new WeakHashMap<>());
private Map<String, ClspClass> nameMap;
private final Set<String> missingClasses = new HashSet<>();
@@ -116,7 +117,7 @@ public class ClspGraph {
* @return {@code clsName} instanceof {@code implClsName}
*/
public boolean isImplements(String clsName, String implClsName) {
Set<String> anc = getAncestors(clsName);
Set<String> anc = getSuperTypes(clsName);
return anc.contains(implClsName);
}
@@ -142,7 +143,7 @@ public class ClspGraph {
if (isImplements(clsName, implClsName)) {
return implClsName;
}
Set<String> anc = getAncestors(clsName);
Set<String> anc = getSuperTypes(clsName);
return searchCommonParent(anc, cls);
}
@@ -163,35 +164,42 @@ public class ClspGraph {
return null;
}
public Set<String> getAncestors(String clsName) {
Set<String> result = ancestorCache.get(clsName);
if (result != null) {
return result;
public Set<String> getSuperTypes(String clsName) {
Set<String> fromCache = superTypesCache.get(clsName);
if (fromCache != null) {
return fromCache;
}
ClspClass cls = nameMap.get(clsName);
if (cls == null) {
missingClasses.add(clsName);
return Collections.emptySet();
}
result = new HashSet<>();
addAncestorsNames(cls, result);
Set<String> result = new HashSet<>();
addSuperTypes(cls, result);
return putInSuperTypesCache(clsName, result);
}
@NotNull
private Set<String> putInSuperTypesCache(String clsName, Set<String> result) {
if (result.isEmpty()) {
result = Collections.emptySet();
Set<String> empty = Collections.emptySet();
superTypesCache.put(clsName, result);
return empty;
}
ancestorCache.put(clsName, result);
superTypesCache.put(clsName, result);
return result;
}
private void addAncestorsNames(ClspClass cls, Set<String> result) {
boolean isNew = result.add(cls.getName());
if (isNew) {
for (ArgType parentType : cls.getParents()) {
if (parentType == null) {
continue;
}
ClspClass parentCls = getClspClass(parentType);
if (parentCls != null) {
addAncestorsNames(parentCls, result);
private void addSuperTypes(ClspClass cls, Set<String> result) {
for (ArgType parentType : cls.getParents()) {
if (parentType == null) {
continue;
}
ClspClass parentCls = getClspClass(parentType);
if (parentCls != null) {
boolean isNew = result.add(parentCls.getName());
if (isNew) {
addSuperTypes(parentCls, result);
}
}
}
@@ -317,7 +317,7 @@ public class ClassGen {
public void addMethodCode(CodeWriter code, MethodNode mth) throws CodegenException {
CodeGenUtils.addComments(code, mth);
if (mth.getAccessFlags().isAbstract() || mth.getAccessFlags().isNative()) {
if (mth.isNoCode()) {
MethodGen mthGen = new MethodGen(this, mth);
mthGen.addDefinition(code);
code.add(';');
@@ -14,6 +14,7 @@ import jadx.core.dex.attributes.AFlag;
import jadx.core.dex.attributes.AType;
import jadx.core.dex.attributes.annotations.MethodParameters;
import jadx.core.dex.attributes.nodes.JumpInfo;
import jadx.core.dex.attributes.nodes.MethodOverrideAttr;
import jadx.core.dex.info.AccessInfo;
import jadx.core.dex.instructions.IfNode;
import jadx.core.dex.instructions.InsnType;
@@ -21,6 +22,7 @@ import jadx.core.dex.instructions.args.ArgType;
import jadx.core.dex.instructions.args.CodeVar;
import jadx.core.dex.instructions.args.RegisterArg;
import jadx.core.dex.instructions.args.SSAVar;
import jadx.core.dex.nodes.IMethodDetails;
import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.trycatch.CatchAttr;
@@ -72,6 +74,7 @@ public class MethodGen {
code.attachDefinition(mth);
return false;
}
addOverrideAnnotation(code, mth);
annotationGen.addForMethod(code, mth);
AccessInfo clsAccFlags = mth.getParentClass().getAccessFlags();
@@ -146,6 +149,23 @@ public class MethodGen {
return true;
}
private void addOverrideAnnotation(CodeWriter code, MethodNode mth) {
MethodOverrideAttr overrideAttr = mth.get(AType.METHOD_OVERRIDE);
if (overrideAttr == null) {
return;
}
code.startLine("@Override");
code.add(" // ");
Iterator<IMethodDetails> it = overrideAttr.getOverrideList().iterator();
while (it.hasNext()) {
IMethodDetails methodDetails = it.next();
code.add(methodDetails.getMethodInfo().getDeclClass().getAliasFullName());
if (it.hasNext()) {
code.add(", ");
}
}
}
private void addMethodArguments(CodeWriter code, List<RegisterArg> args) {
MethodParameters paramsAnnotation = mth.get(AType.ANNOTATION_MTH_PARAMETERS);
int i = 0;
@@ -20,6 +20,7 @@ import jadx.core.dex.attributes.nodes.LocalVarsDebugInfoAttr;
import jadx.core.dex.attributes.nodes.LoopInfo;
import jadx.core.dex.attributes.nodes.LoopLabelAttr;
import jadx.core.dex.attributes.nodes.MethodInlineAttr;
import jadx.core.dex.attributes.nodes.MethodOverrideAttr;
import jadx.core.dex.attributes.nodes.PhiListAttr;
import jadx.core.dex.attributes.nodes.RegDebugInfoAttr;
import jadx.core.dex.attributes.nodes.RenameReasonAttr;
@@ -63,6 +64,7 @@ public class AType<T extends IAttribute> {
public static final AType<MethodInlineAttr> METHOD_INLINE = new AType<>();
public static final AType<MethodParameters> ANNOTATION_MTH_PARAMETERS = new AType<>();
public static final AType<SkipMethodArgsAttr> SKIP_MTH_ARGS = new AType<>();
public static final AType<MethodOverrideAttr> METHOD_OVERRIDE = new AType<>();
// region
public static final AType<DeclareVariablesAttr> DECLARE_VARIABLES = new AType<>();
@@ -0,0 +1,30 @@
package jadx.core.dex.attributes.nodes;
import java.util.List;
import jadx.core.dex.attributes.AType;
import jadx.core.dex.attributes.IAttribute;
import jadx.core.dex.nodes.IMethodDetails;
public class MethodOverrideAttr implements IAttribute {
private final List<IMethodDetails> overrideList;
public MethodOverrideAttr(List<IMethodDetails> overrideList) {
this.overrideList = overrideList;
}
public List<IMethodDetails> getOverrideList() {
return overrideList;
}
@Override
public AType<MethodOverrideAttr> getType() {
return AType.METHOD_OVERRIDE;
}
@Override
public String toString() {
return "METHOD_OVERRIDE: " + overrideList;
}
}
@@ -10,6 +10,7 @@ import jadx.core.dex.info.ClassInfo;
import jadx.core.dex.nodes.ClassNode;
import jadx.core.dex.nodes.DexNode;
import jadx.core.dex.nodes.RootNode;
import jadx.core.dex.visitors.typeinference.TypeCompareEnum;
import jadx.core.utils.Utils;
public abstract class ArgType {
@@ -578,11 +579,8 @@ public abstract class ArgType {
if (from.equals(to)) {
return false;
}
if (from.isObject() && to.isObject()
&& root.getClsp().isImplements(from.getObject(), to.getObject())) {
return false;
}
return true;
TypeCompareEnum result = root.getTypeUpdate().getTypeCompare().compareTypes(from, to);
return !result.isNarrow();
}
public static boolean isInstanceOf(RootNode root, ArgType type, ArgType of) {
@@ -180,6 +180,7 @@ public class MethodNode extends NotificationAttrNode implements IMethodDetails,
if (types == null) {
this.retType = mthInfo.getReturnType();
this.argTypes = mthInfo.getArgumentsTypes();
this.typeParameters = Collections.emptyList();
} else {
this.argTypes = Collections.unmodifiableList(types);
}
@@ -283,6 +284,10 @@ public class MethodNode extends NotificationAttrNode implements IMethodDetails,
return retType;
}
public void updateReturnType(ArgType type) {
this.retType = type;
}
public boolean isVoidReturn() {
return mthInfo.getReturnType().equals(ArgType.VOID);
}
@@ -50,13 +50,14 @@ public class TypeUtils {
*/
@Nullable
public ArgType replaceClassGenerics(ArgType instanceType, ArgType typeWithGeneric) {
if (typeWithGeneric != null) {
Map<ArgType, ArgType> replaceMap = getTypeVariablesMapping(instanceType);
if (!replaceMap.isEmpty()) {
return replaceTypeVariablesUsingMap(typeWithGeneric, replaceMap);
}
if (typeWithGeneric == null) {
return null;
}
return null;
Map<ArgType, ArgType> replaceMap = getTypeVariablesMapping(instanceType);
if (replaceMap.isEmpty()) {
return null;
}
return replaceTypeVariablesUsingMap(typeWithGeneric, replaceMap);
}
public Map<ArgType, ArgType> getTypeVariablesMapping(ArgType clsType) {
@@ -0,0 +1,148 @@
package jadx.core.dex.visitors;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import jadx.core.clsp.ClspClass;
import jadx.core.clsp.ClspMethod;
import jadx.core.dex.attributes.nodes.MethodOverrideAttr;
import jadx.core.dex.instructions.args.ArgType;
import jadx.core.dex.nodes.ClassNode;
import jadx.core.dex.nodes.IMethodDetails;
import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.nodes.RootNode;
import jadx.core.dex.visitors.typeinference.TypeCompare;
import jadx.core.dex.visitors.typeinference.TypeCompareEnum;
import jadx.core.dex.visitors.typeinference.TypeInferenceVisitor;
import jadx.core.utils.exceptions.JadxException;
@JadxVisitor(
name = "OverrideMethodVisitor",
desc = "Mark override methods and revert type erasure",
runBefore = {
TypeInferenceVisitor.class
}
)
public class OverrideMethodVisitor extends AbstractVisitor {
@Override
public boolean visit(ClassNode cls) throws JadxException {
RootNode root = cls.root();
List<ArgType> superTypes = collectSuperTypes(cls);
for (MethodNode mth : cls.getMethods()) {
if (mth.isConstructor()) {
continue;
}
String signature = mth.getMethodInfo().makeSignature(false);
List<IMethodDetails> overrideList = collectOverrideMethods(root, superTypes, signature);
if (!overrideList.isEmpty()) {
mth.addAttr(new MethodOverrideAttr(overrideList));
fixMethodReturnType(mth, overrideList, superTypes);
}
}
return true;
}
private List<IMethodDetails> collectOverrideMethods(RootNode root, List<ArgType> superTypes, String signature) {
List<IMethodDetails> overrideList = new ArrayList<>();
for (ArgType superType : superTypes) {
ClassNode classNode = root.resolveClass(superType);
if (classNode != null) {
for (MethodNode mth : classNode.getMethods()) {
String mthShortId = mth.getMethodInfo().getShortId();
if (mthShortId.startsWith(signature)) {
overrideList.add(mth);
}
}
} else {
ClspClass clsDetails = root.getClsp().getClsDetails(superType);
if (clsDetails != null) {
Map<String, ClspMethod> methodsMap = clsDetails.getMethodsMap();
for (Map.Entry<String, ClspMethod> entry : methodsMap.entrySet()) {
String mthShortId = entry.getKey();
if (mthShortId.startsWith(signature)) {
overrideList.add(entry.getValue());
}
}
}
}
}
return overrideList;
}
private List<ArgType> collectSuperTypes(ClassNode cls) {
Map<String, ArgType> superTypes = new HashMap<>();
collectSuperTypes(cls, superTypes);
return new ArrayList<>(superTypes.values());
}
private void collectSuperTypes(ClassNode cls, Map<String, ArgType> superTypes) {
RootNode root = cls.root();
ArgType superClass = cls.getSuperClass();
if (superClass != null && !Objects.equals(superClass, ArgType.OBJECT)) {
addSuperType(root, superTypes, superClass);
}
for (ArgType iface : cls.getInterfaces()) {
addSuperType(root, superTypes, iface);
}
}
private void addSuperType(RootNode root, Map<String, ArgType> superTypesMap, ArgType superType) {
superTypesMap.put(superType.getObject(), superType);
ClassNode classNode = root.resolveClass(superType);
if (classNode == null) {
for (String superCls : root.getClsp().getSuperTypes(superType.getObject())) {
ArgType type = ArgType.object(superCls);
superTypesMap.put(type.getObject(), type);
}
} else {
collectSuperTypes(classNode, superTypesMap);
}
}
private void fixMethodReturnType(MethodNode mth, List<IMethodDetails> overrideList, List<ArgType> superTypes) {
ArgType returnType = mth.getReturnType();
int updateCount = 0;
for (IMethodDetails baseMth : overrideList) {
if (updateReturnType(mth, baseMth, superTypes)) {
updateCount++;
}
}
if (updateCount == 0) {
return;
}
if (updateCount == 1) {
mth.addComment("Return type fixed from '" + returnType + "' to match base method");
} else {
mth.addWarnComment("Due to multiple override return type can be incorrect, original value: " + returnType);
}
}
private boolean updateReturnType(MethodNode mth, IMethodDetails baseMth, List<ArgType> superTypes) {
ArgType baseReturnType = baseMth.getReturnType();
if (mth.getReturnType().equals(baseReturnType)) {
return false;
}
if (!baseReturnType.containsTypeVariable()) {
return false;
}
TypeCompare typeCompare = mth.root().getTypeUpdate().getTypeCompare();
ArgType baseCls = baseMth.getMethodInfo().getDeclClass().getType();
for (ArgType superType : superTypes) {
TypeCompareEnum compareResult = typeCompare.compareTypes(superType, baseCls);
if (compareResult == TypeCompareEnum.NARROW_BY_GENERIC) {
ArgType targetRetType = mth.root().getTypeUtils().replaceClassGenerics(superType, baseReturnType);
if (targetRetType != null
&& !targetRetType.containsTypeVariable()
&& !targetRetType.equals(mth.getReturnType())) {
mth.updateReturnType(targetRetType);
return true;
}
}
}
return false;
}
}
@@ -521,7 +521,7 @@ public final class TypeInferenceVisitor extends AbstractVisitor {
}
ClspGraph clsp = mth.root().getClsp();
for (ArgType objType : objTypes) {
for (String ancestor : clsp.getAncestors(objType.getObject())) {
for (String ancestor : clsp.getSuperTypes(objType.getObject())) {
ArgType ancestorType = ArgType.object(ancestor);
TypeUpdateResult result = typeUpdate.applyWithWiderAllow(var, ancestorType);
if (result == TypeUpdateResult.CHANGED) {
@@ -281,7 +281,7 @@ public class TypeSearch {
private List<ArgType> getWiderTypes(ArgType type) {
if (type.isTypeKnown()) {
if (type.isObject()) {
Set<String> ancestors = mth.root().getClsp().getAncestors(type.getObject());
Set<String> ancestors = mth.root().getClsp().getSuperTypes(type.getObject());
return ancestors.stream().map(ArgType::object).collect(Collectors.toList());
}
} else {
@@ -20,6 +20,7 @@ import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.jar.JarOutputStream;
import org.jetbrains.annotations.NotNull;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
@@ -458,6 +459,13 @@ public abstract class IntegrationTest extends TestUtils {
return files;
}
@NotNull
protected static String removeLineComments(ClassNode cls) {
String code = cls.getCode().getCodeStr().replaceAll("\\W*//.*", "");
System.out.println(code);
return code;
}
public JadxArgs getArgs() {
return args;
}
@@ -14,11 +14,13 @@ public class TestEnums2 extends IntegrationTest {
public enum Operation {
PLUS {
@Override
public int apply(int x, int y) {
return x + y;
}
},
MINUS {
@Override
public int apply(int x, int y) {
return x - y;
}
@@ -31,16 +33,18 @@ public class TestEnums2 extends IntegrationTest {
@Test
public void test() {
ClassNode cls = getClassNode(TestCls.class);
String code = cls.getCode().toString();
String code = removeLineComments(cls);
assertThat(code, JadxMatchers.containsLines(1,
"public enum Operation {",
indent(1) + "PLUS {",
indent(2) + "@Override",
indent(2) + "public int apply(int x, int y) {",
indent(3) + "return x + y;",
indent(2) + '}',
indent(1) + "},",
indent(1) + "MINUS {",
indent(2) + "@Override",
indent(2) + "public int apply(int x, int y) {",
indent(3) + "return x - y;",
indent(2) + '}',
@@ -35,16 +35,18 @@ public class TestEnumsInterface extends IntegrationTest {
@Test
public void test() {
ClassNode cls = getClassNode(TestCls.class);
String code = cls.getCode().toString();
String code = removeLineComments(cls);
assertThat(code, JadxMatchers.containsLines(1,
"public enum Operation implements IOperation {",
indent(1) + "PLUS {",
indent(2) + "@Override",
indent(2) + "public int apply(int x, int y) {",
indent(3) + "return x + y;",
indent(2) + '}',
indent(1) + "},",
indent(1) + "MINUS {",
indent(2) + "@Override",
indent(2) + "public int apply(int x, int y) {",
indent(3) + "return x - y;",
indent(2) + '}',
@@ -6,6 +6,7 @@ import jadx.core.dex.nodes.ClassNode;
import jadx.tests.api.IntegrationTest;
import static jadx.tests.api.utils.JadxMatchers.containsOne;
import static jadx.tests.api.utils.JadxMatchers.countString;
import static org.hamcrest.MatcherAssert.assertThat;
public class TestGenericsMthOverride extends IntegrationTest {
@@ -54,5 +55,9 @@ public class TestGenericsMthOverride extends IntegrationTest {
assertThat(code, containsOne("public Y method(Object x) {"));
assertThat(code, containsOne("public Y method(Exception x) {"));
assertThat(code, containsOne("public Object method(Object x) {"));
assertThat(code, countString(3, "@Override"));
// TODO: @Override missing for class C
// assertThat(code, countString(4, "@Override"));
}
}
@@ -0,0 +1,20 @@
package jadx.tests.integration.generics;
import org.junit.jupiter.api.Test;
import jadx.tests.api.SmaliTest;
import static jadx.tests.api.utils.assertj.JadxAssertions.assertThat;
public class TestMethodOverride extends SmaliTest {
@Test
public void test() {
disableCompilation();
assertThat(getClassNodeFromSmali())
.code()
.containsOne("String createFromParcel(Parcel parcel) {")
.containsOne("@Override");
}
}
@@ -0,0 +1,38 @@
.class public final Lgenerics/TestMethodOverride;
.super Ljava/lang/Object;
# interfaces
.implements Landroid/os/Parcelable$Creator;
# annotations
.annotation system Ldalvik/annotation/Signature;
value = {
"Ljava/lang/Object;",
"Landroid/os/Parcelable$Creator<",
"Ljava/lang/String;",
">;"
}
.end annotation
# direct methods
.method public constructor <init>()V
.registers 1
.line 1
invoke-direct {p0}, Ljava/lang/Object;-><init>()V
return-void
.end method
# virtual methods
.method public final synthetic createFromParcel(Landroid/os/Parcel;)Ljava/lang/Object;
.registers 2
const/4 v0, 0x0
return-object v0
.end method