From d720179debbb27c7deadbfcb231bcade9176f966 Mon Sep 17 00:00:00 2001 From: Skylot Date: Thu, 21 May 2020 21:56:58 +0100 Subject: [PATCH] fix: collect class usage and fix class access modifiers (#729) --- .../main/java/jadx/api/JadxDecompiler.java | 1 + jadx-core/src/main/java/jadx/core/Jadx.java | 11 +- .../src/main/java/jadx/core/ProcessClass.java | 2 +- .../java/jadx/core/dex/nodes/ClassNode.java | 22 ++- .../java/jadx/core/dex/nodes/RootNode.java | 18 +++ .../dex/visitors/DependencyCollector.java | 144 ++++++++++++------ .../core/dex/visitors/FixAccessModifiers.java | 46 +++++- .../functional/JadxVisitorsOrderTest.java | 10 +- 8 files changed, 193 insertions(+), 61 deletions(-) diff --git a/jadx-core/src/main/java/jadx/api/JadxDecompiler.java b/jadx-core/src/main/java/jadx/api/JadxDecompiler.java index 65e1c8bd9..4e78b4308 100644 --- a/jadx-core/src/main/java/jadx/api/JadxDecompiler.java +++ b/jadx-core/src/main/java/jadx/api/JadxDecompiler.java @@ -98,6 +98,7 @@ public final class JadxDecompiler implements Closeable { root.initClassPath(); root.loadResources(getResources()); root.initPasses(); + root.runPreDecompileStage(); } private void loadInputFiles() { diff --git a/jadx-core/src/main/java/jadx/core/Jadx.java b/jadx-core/src/main/java/jadx/core/Jadx.java index 51846927a..84889304f 100644 --- a/jadx-core/src/main/java/jadx/core/Jadx.java +++ b/jadx-core/src/main/java/jadx/core/Jadx.java @@ -76,6 +76,13 @@ public class Jadx { return passes; } + public static List getPreDecompilePassesList() { + List passes = new ArrayList<>(); + passes.add(new RenameVisitor()); + passes.add(new DependencyCollector()); + return passes; + } + public static List getPassesList(JadxArgs args) { if (args.isFallbackMode()) { return getFallbackPassesList(); @@ -146,10 +153,6 @@ public class Jadx { if (args.isCfgOutput()) { passes.add(DotGraphVisitor.dumpRegions()); } - - passes.add(new DependencyCollector()); - passes.add(new RenameVisitor()); - return passes; } diff --git a/jadx-core/src/main/java/jadx/core/ProcessClass.java b/jadx-core/src/main/java/jadx/core/ProcessClass.java index 35e82ebc3..5e0bad411 100644 --- a/jadx-core/src/main/java/jadx/core/ProcessClass.java +++ b/jadx-core/src/main/java/jadx/core/ProcessClass.java @@ -54,8 +54,8 @@ public final class ProcessClass { return generateCode(topParentClass); } try { - process(cls); cls.getDependencies().forEach(ProcessClass::process); + process(cls); ICodeInfo code = CodeGen.generate(cls); cls.unload(); diff --git a/jadx-core/src/main/java/jadx/core/dex/nodes/ClassNode.java b/jadx-core/src/main/java/jadx/core/dex/nodes/ClassNode.java index ca1516337..0a7ad5321 100644 --- a/jadx-core/src/main/java/jadx/core/dex/nodes/ClassNode.java +++ b/jadx-core/src/main/java/jadx/core/dex/nodes/ClassNode.java @@ -12,6 +12,7 @@ import java.util.Objects; import java.util.Set; import java.util.stream.Collectors; +import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.Nullable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -43,7 +44,7 @@ import jadx.core.utils.exceptions.JadxRuntimeException; import static jadx.core.dex.nodes.ProcessState.LOADED; import static jadx.core.dex.nodes.ProcessState.NOT_LOADED; -public class ClassNode extends NotificationAttrNode implements ILoadable, ICodeNode { +public class ClassNode extends NotificationAttrNode implements ILoadable, ICodeNode, Comparable { private static final Logger LOG = LoggerFactory.getLogger(ClassNode.class); private final RootNode root; @@ -69,6 +70,7 @@ public class ClassNode extends NotificationAttrNode implements ILoadable, ICodeN private volatile ProcessState state = ProcessState.NOT_LOADED; private List dependencies = Collections.emptyList(); + private List usedIn = Collections.emptyList(); // cache maps private Map mthInfoMap = Collections.emptyMap(); @@ -478,6 +480,10 @@ public class ClassNode extends NotificationAttrNode implements ILoadable, ICodeN return contains(AFlag.ANONYMOUS_CLASS); } + public boolean isInner() { + return parentClass != null; + } + @Nullable public MethodNode getClassInitMth() { return searchMethodByShortId("()V"); @@ -579,6 +585,14 @@ public class ClassNode extends NotificationAttrNode implements ILoadable, ICodeN this.dependencies = dependencies; } + public List getUsedIn() { + return usedIn; + } + + public void setUsedIn(List usedIn) { + this.usedIn = usedIn; + } + @Override public Path getInputPath() { return inputPath; @@ -601,9 +615,13 @@ public class ClassNode extends NotificationAttrNode implements ILoadable, ICodeN return false; } + @Override + public int compareTo(@NotNull ClassNode o) { + return this.getFullName().compareTo(o.getFullName()); + } + @Override public String toString() { return clsInfo.getFullName(); } - } diff --git a/jadx-core/src/main/java/jadx/core/dex/nodes/RootNode.java b/jadx-core/src/main/java/jadx/core/dex/nodes/RootNode.java index 10ddbdac1..ee95b0297 100644 --- a/jadx-core/src/main/java/jadx/core/dex/nodes/RootNode.java +++ b/jadx-core/src/main/java/jadx/core/dex/nodes/RootNode.java @@ -28,6 +28,7 @@ import jadx.core.dex.info.MethodInfo; import jadx.core.dex.instructions.args.ArgType; import jadx.core.dex.nodes.utils.MethodUtils; import jadx.core.dex.nodes.utils.TypeUtils; +import jadx.core.dex.visitors.DepthTraversal; import jadx.core.dex.visitors.IDexTreeVisitor; import jadx.core.dex.visitors.typeinference.TypeUpdate; import jadx.core.utils.CacheStorage; @@ -189,10 +190,27 @@ public class RootNode { } } + public void runPreDecompileStage() { + for (IDexTreeVisitor pass : Jadx.getPreDecompilePassesList()) { + try { + pass.init(this); + } catch (Exception e) { + LOG.error("Visitor init failed: {}", pass.getClass().getSimpleName(), e); + } + for (ClassNode cls : classes) { + DepthTraversal.visit(pass, cls); + } + } + } + public List getClasses() { return classes; } + public List getClassesWithoutInner() { + return getClasses(false); + } + public List getClasses(boolean includeInner) { if (includeInner) { return classes; diff --git a/jadx-core/src/main/java/jadx/core/dex/visitors/DependencyCollector.java b/jadx-core/src/main/java/jadx/core/dex/visitors/DependencyCollector.java index 0e4b4b4c7..9c39c31f5 100644 --- a/jadx-core/src/main/java/jadx/core/dex/visitors/DependencyCollector.java +++ b/jadx-core/src/main/java/jadx/core/dex/visitors/DependencyCollector.java @@ -1,33 +1,61 @@ package jadx.core.dex.visitors; import java.util.ArrayList; -import java.util.Comparator; +import java.util.Collections; import java.util.HashSet; import java.util.List; import java.util.Set; +import jadx.api.plugins.input.data.ICodeReader; +import jadx.api.plugins.input.data.IFieldData; +import jadx.api.plugins.input.data.IMethodData; +import jadx.api.plugins.input.insns.InsnData; +import jadx.api.plugins.input.insns.Opcode; import jadx.core.dex.attributes.AType; -import jadx.core.dex.attributes.FieldInitAttr; import jadx.core.dex.info.ClassInfo; -import jadx.core.dex.info.FieldInfo; -import jadx.core.dex.instructions.BaseInvokeNode; -import jadx.core.dex.instructions.IndexInsnNode; import jadx.core.dex.instructions.args.ArgType; -import jadx.core.dex.instructions.args.InsnArg; -import jadx.core.dex.instructions.args.InsnWrapArg; -import jadx.core.dex.instructions.args.RegisterArg; -import jadx.core.dex.nodes.BlockNode; import jadx.core.dex.nodes.ClassNode; import jadx.core.dex.nodes.FieldNode; -import jadx.core.dex.nodes.InsnNode; import jadx.core.dex.nodes.MethodNode; import jadx.core.dex.nodes.RootNode; -import jadx.core.utils.exceptions.JadxException; +@JadxVisitor( + name = "DependencyCollector", + desc = "Scan class and methods and collect dependant classes", + runAfter = { + RenameVisitor.class // sort by alias name + } +) +// TODO: store usage info for fields, methods and inner classes public class DependencyCollector extends AbstractVisitor { @Override - public boolean visit(ClassNode cls) throws JadxException { + public void init(RootNode root) { + List clsList = root.getClassesWithoutInner(); + for (ClassNode cls : clsList) { + collectClassDeps(cls); + } + buildUsageList(clsList); + } + + private void buildUsageList(List clsList) { + clsList.forEach(cls -> cls.setUsedIn(new ArrayList<>())); + for (ClassNode cls : clsList) { + for (ClassNode depCls : cls.getDependencies()) { + depCls.getUsedIn().add(cls); + } + } + for (ClassNode cls : clsList) { + List usedIn = cls.getUsedIn(); + if (usedIn.isEmpty()) { + cls.setUsedIn(Collections.emptyList()); + } else { + Collections.sort(usedIn); + } + } + } + + public void collectClassDeps(ClassNode cls) { RootNode root = cls.root(); Set depSet = new HashSet<>(); processClass(cls, root, depSet); @@ -36,10 +64,13 @@ public class DependencyCollector extends AbstractVisitor { } depSet.remove(cls); - List depList = new ArrayList<>(depSet); - depList.sort(Comparator.comparing(c -> c.getClassInfo().getFullName())); - cls.setDependencies(depList); - return false; + if (depSet.isEmpty()) { + cls.setDependencies(Collections.emptyList()); + } else { + List depList = new ArrayList<>(depSet); + Collections.sort(depList); + cls.setDependencies(depList); + } } private static void processClass(ClassNode cls, RootNode root, Set depList) { @@ -49,12 +80,6 @@ public class DependencyCollector extends AbstractVisitor { } for (FieldNode fieldNode : cls.getFields()) { addDep(root, depList, fieldNode.getType()); - - // process instructions from field init - FieldInitAttr fieldInitAttr = fieldNode.get(AType.FIELD_INIT); - if (fieldInitAttr != null && fieldInitAttr.getValueType() == FieldInitAttr.InitType.INSN) { - processInsn(root, depList, fieldInitAttr.getInsn()); - } } // TODO: process annotations and generics for (MethodNode methodNode : cls.getMethods()) { @@ -71,41 +96,62 @@ public class DependencyCollector extends AbstractVisitor { for (ArgType arg : methodNode.getMethodInfo().getArgumentsTypes()) { addDep(root, depList, arg); } - for (BlockNode block : methodNode.getBasicBlocks()) { - for (InsnNode insnNode : block.getInstructions()) { - processInsn(root, depList, insnNode); - } + try { + processInstructions(methodNode, depList); + } catch (Exception e) { + methodNode.getCodeReader().visitInstructions(insnData -> { + insnData.decode(); + System.out.println(insnData); + }); + methodNode.addError("Dependency scan failed", e); } } - // TODO: add custom instructions processing - private static void processInsn(RootNode root, Set depList, InsnNode insnNode) { - RegisterArg result = insnNode.getResult(); - if (result != null) { - addDep(root, depList, result.getType()); + private static void processInstructions(MethodNode mth, Set deps) { + ICodeReader codeReader = mth.getCodeReader(); + if (codeReader == null) { + return; } - for (InsnArg arg : insnNode.getArguments()) { - if (arg.isInsnWrap()) { - processInsn(root, depList, ((InsnWrapArg) arg).getWrapInsn()); - } else { - addDep(root, depList, arg.getType()); + RootNode root = mth.root(); + codeReader.visitInstructions(insnData -> { + try { + processInsn(root, insnData, deps); + } catch (Exception e) { + mth.addError("Dependency scan failed at insn: " + insnData, e); } - } - processCustomInsn(root, depList, insnNode); + }); } - private static void processCustomInsn(RootNode root, Set depList, InsnNode insn) { - if (insn instanceof IndexInsnNode) { - Object index = ((IndexInsnNode) insn).getIndex(); - if (index instanceof FieldInfo) { - addDep(root, depList, ((FieldInfo) index).getDeclClass()); - } else if (index instanceof ArgType) { - addDep(root, depList, (ArgType) index); - } - } else if (insn instanceof BaseInvokeNode) { - ClassInfo declClass = ((BaseInvokeNode) insn).getCallMth().getDeclClass(); - addDep(root, depList, declClass); + private static void processInsn(RootNode root, InsnData insnData, Set deps) { + if (insnData.getOpcode() == Opcode.UNKNOWN) { + return; } + switch (insnData.getIndexType()) { + case TYPE_REF: + insnData.decode(); + resolveType(root, deps, insnData.getIndexAsType()); + break; + case FIELD_REF: + insnData.decode(); + resolveField(root, deps, insnData.getIndexAsField()); + break; + case METHOD_REF: + insnData.decode(); + resolveMethod(root, deps, insnData.getIndexAsMethod()); + break; + } + } + + private static void resolveType(RootNode root, Set deps, String type) { + addDep(root, deps, ArgType.parse(type)); + } + + private static void resolveMethod(RootNode root, Set deps, IMethodData method) { + resolveType(root, deps, method.getParentClassType()); + } + + private static void resolveField(RootNode root, Set deps, IFieldData field) { + resolveType(root, deps, field.getParentClassType()); } private static void addDep(RootNode root, Set depList, ArgType type) { diff --git a/jadx-core/src/main/java/jadx/core/dex/visitors/FixAccessModifiers.java b/jadx-core/src/main/java/jadx/core/dex/visitors/FixAccessModifiers.java index 322350f18..a7166be8a 100644 --- a/jadx-core/src/main/java/jadx/core/dex/visitors/FixAccessModifiers.java +++ b/jadx-core/src/main/java/jadx/core/dex/visitors/FixAccessModifiers.java @@ -3,9 +3,11 @@ package jadx.core.dex.visitors; import jadx.api.plugins.input.data.AccessFlags; import jadx.core.dex.attributes.AType; import jadx.core.dex.info.AccessInfo; +import jadx.core.dex.nodes.ClassNode; import jadx.core.dex.nodes.ICodeNode; import jadx.core.dex.nodes.MethodNode; import jadx.core.dex.nodes.RootNode; +import jadx.core.utils.exceptions.JadxException; @JadxVisitor( name = "FixAccessModifiers", @@ -21,12 +23,24 @@ public class FixAccessModifiers extends AbstractVisitor { this.respectAccessModifiers = root.getArgs().isRespectBytecodeAccModifiers(); } + @Override + public boolean visit(ClassNode cls) throws JadxException { + if (respectAccessModifiers) { + return true; + } + int newVisFlag = fixClassVisibility(cls); + if (newVisFlag != -1) { + changeVisibility(cls, newVisFlag); + } + return true; + } + @Override public void visit(MethodNode mth) { if (respectAccessModifiers) { return; } - int newVisFlag = fixVisibility(mth); + int newVisFlag = fixMethodVisibility(mth); if (newVisFlag != -1) { changeVisibility(mth, newVisFlag); } @@ -41,7 +55,35 @@ public class FixAccessModifiers extends AbstractVisitor { } } - private static int fixVisibility(MethodNode mth) { + private int fixClassVisibility(ClassNode cls) { + if (cls.getUsedIn().isEmpty()) { + return -1; + } + AccessInfo accessFlags = cls.getAccessFlags(); + if (accessFlags.isPrivate()) { + if (!cls.isInner()) { + return AccessFlags.PUBLIC; + } + // check if private inner class is used outside + ClassNode topParentClass = cls.getTopParentClass(); + for (ClassNode useCls : cls.getUsedIn()) { + if (useCls.getTopParentClass() != topParentClass) { + return AccessFlags.PUBLIC; + } + } + } + if (accessFlags.isPackagePrivate()) { + String pkg = cls.getPackage(); + for (ClassNode useCls : cls.getUsedIn()) { + if (!useCls.getPackage().equals(pkg)) { + return AccessFlags.PUBLIC; + } + } + } + return -1; + } + + private static int fixMethodVisibility(MethodNode mth) { if (mth.isVirtual()) { // make virtual methods public return AccessFlags.PUBLIC; diff --git a/jadx-core/src/test/java/jadx/tests/functional/JadxVisitorsOrderTest.java b/jadx-core/src/test/java/jadx/tests/functional/JadxVisitorsOrderTest.java index afeeb9b01..b6b15e7a7 100644 --- a/jadx-core/src/test/java/jadx/tests/functional/JadxVisitorsOrderTest.java +++ b/jadx-core/src/test/java/jadx/tests/functional/JadxVisitorsOrderTest.java @@ -18,13 +18,16 @@ import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.empty; public class JadxVisitorsOrderTest { - private static final Logger LOG = LoggerFactory.getLogger(JadxVisitorsOrderTest.class); @Test public void testOrder() { - List passes = Jadx.getPassesList(new JadxArgs()); + checkPassList(Jadx.getPassesList(new JadxArgs())); + checkPassList(Jadx.getPreDecompilePassesList()); + checkPassList(Jadx.getFallbackPassesList()); + } + private void checkPassList(List passes) { List errors = check(passes); for (String str : errors) { LOG.error(str); @@ -55,7 +58,8 @@ public class JadxVisitorsOrderTest { errors.add("Visitor name conflict: " + passName + ", class: " + passClass.getName()); } for (Class cls : info.runBefore()) { - if (classList.indexOf(cls) < i) { + int beforeIndex = classList.indexOf(cls); + if (beforeIndex != -1 && beforeIndex < i) { errors.add("Pass " + passName + " must be before " + cls.getSimpleName()); } }