fix: collect class usage and fix class access modifiers (#729)

This commit is contained in:
Skylot
2020-05-21 21:56:58 +01:00
parent 0d69e0ac97
commit d720179deb
8 changed files with 193 additions and 61 deletions
@@ -98,6 +98,7 @@ public final class JadxDecompiler implements Closeable {
root.initClassPath();
root.loadResources(getResources());
root.initPasses();
root.runPreDecompileStage();
}
private void loadInputFiles() {
+7 -4
View File
@@ -76,6 +76,13 @@ public class Jadx {
return passes;
}
public static List<IDexTreeVisitor> getPreDecompilePassesList() {
List<IDexTreeVisitor> passes = new ArrayList<>();
passes.add(new RenameVisitor());
passes.add(new DependencyCollector());
return passes;
}
public static List<IDexTreeVisitor> getPassesList(JadxArgs args) {
if (args.isFallbackMode()) {
return getFallbackPassesList();
@@ -146,10 +153,6 @@ public class Jadx {
if (args.isCfgOutput()) {
passes.add(DotGraphVisitor.dumpRegions());
}
passes.add(new DependencyCollector());
passes.add(new RenameVisitor());
return passes;
}
@@ -54,8 +54,8 @@ public final class ProcessClass {
return generateCode(topParentClass);
}
try {
process(cls);
cls.getDependencies().forEach(ProcessClass::process);
process(cls);
ICodeInfo code = CodeGen.generate(cls);
cls.unload();
@@ -12,6 +12,7 @@ import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -43,7 +44,7 @@ import jadx.core.utils.exceptions.JadxRuntimeException;
import static jadx.core.dex.nodes.ProcessState.LOADED;
import static jadx.core.dex.nodes.ProcessState.NOT_LOADED;
public class ClassNode extends NotificationAttrNode implements ILoadable, ICodeNode {
public class ClassNode extends NotificationAttrNode implements ILoadable, ICodeNode, Comparable<ClassNode> {
private static final Logger LOG = LoggerFactory.getLogger(ClassNode.class);
private final RootNode root;
@@ -69,6 +70,7 @@ public class ClassNode extends NotificationAttrNode implements ILoadable, ICodeN
private volatile ProcessState state = ProcessState.NOT_LOADED;
private List<ClassNode> dependencies = Collections.emptyList();
private List<ClassNode> usedIn = Collections.emptyList();
// cache maps
private Map<MethodInfo, MethodNode> mthInfoMap = Collections.emptyMap();
@@ -478,6 +480,10 @@ public class ClassNode extends NotificationAttrNode implements ILoadable, ICodeN
return contains(AFlag.ANONYMOUS_CLASS);
}
public boolean isInner() {
return parentClass != null;
}
@Nullable
public MethodNode getClassInitMth() {
return searchMethodByShortId("<clinit>()V");
@@ -579,6 +585,14 @@ public class ClassNode extends NotificationAttrNode implements ILoadable, ICodeN
this.dependencies = dependencies;
}
public List<ClassNode> getUsedIn() {
return usedIn;
}
public void setUsedIn(List<ClassNode> usedIn) {
this.usedIn = usedIn;
}
@Override
public Path getInputPath() {
return inputPath;
@@ -601,9 +615,13 @@ public class ClassNode extends NotificationAttrNode implements ILoadable, ICodeN
return false;
}
@Override
public int compareTo(@NotNull ClassNode o) {
return this.getFullName().compareTo(o.getFullName());
}
@Override
public String toString() {
return clsInfo.getFullName();
}
}
@@ -28,6 +28,7 @@ import jadx.core.dex.info.MethodInfo;
import jadx.core.dex.instructions.args.ArgType;
import jadx.core.dex.nodes.utils.MethodUtils;
import jadx.core.dex.nodes.utils.TypeUtils;
import jadx.core.dex.visitors.DepthTraversal;
import jadx.core.dex.visitors.IDexTreeVisitor;
import jadx.core.dex.visitors.typeinference.TypeUpdate;
import jadx.core.utils.CacheStorage;
@@ -189,10 +190,27 @@ public class RootNode {
}
}
public void runPreDecompileStage() {
for (IDexTreeVisitor pass : Jadx.getPreDecompilePassesList()) {
try {
pass.init(this);
} catch (Exception e) {
LOG.error("Visitor init failed: {}", pass.getClass().getSimpleName(), e);
}
for (ClassNode cls : classes) {
DepthTraversal.visit(pass, cls);
}
}
}
public List<ClassNode> getClasses() {
return classes;
}
public List<ClassNode> getClassesWithoutInner() {
return getClasses(false);
}
public List<ClassNode> getClasses(boolean includeInner) {
if (includeInner) {
return classes;
@@ -1,33 +1,61 @@
package jadx.core.dex.visitors;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import jadx.api.plugins.input.data.ICodeReader;
import jadx.api.plugins.input.data.IFieldData;
import jadx.api.plugins.input.data.IMethodData;
import jadx.api.plugins.input.insns.InsnData;
import jadx.api.plugins.input.insns.Opcode;
import jadx.core.dex.attributes.AType;
import jadx.core.dex.attributes.FieldInitAttr;
import jadx.core.dex.info.ClassInfo;
import jadx.core.dex.info.FieldInfo;
import jadx.core.dex.instructions.BaseInvokeNode;
import jadx.core.dex.instructions.IndexInsnNode;
import jadx.core.dex.instructions.args.ArgType;
import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.instructions.args.InsnWrapArg;
import jadx.core.dex.instructions.args.RegisterArg;
import jadx.core.dex.nodes.BlockNode;
import jadx.core.dex.nodes.ClassNode;
import jadx.core.dex.nodes.FieldNode;
import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.nodes.RootNode;
import jadx.core.utils.exceptions.JadxException;
@JadxVisitor(
name = "DependencyCollector",
desc = "Scan class and methods and collect dependant classes",
runAfter = {
RenameVisitor.class // sort by alias name
}
)
// TODO: store usage info for fields, methods and inner classes
public class DependencyCollector extends AbstractVisitor {
@Override
public boolean visit(ClassNode cls) throws JadxException {
public void init(RootNode root) {
List<ClassNode> clsList = root.getClassesWithoutInner();
for (ClassNode cls : clsList) {
collectClassDeps(cls);
}
buildUsageList(clsList);
}
private void buildUsageList(List<ClassNode> clsList) {
clsList.forEach(cls -> cls.setUsedIn(new ArrayList<>()));
for (ClassNode cls : clsList) {
for (ClassNode depCls : cls.getDependencies()) {
depCls.getUsedIn().add(cls);
}
}
for (ClassNode cls : clsList) {
List<ClassNode> usedIn = cls.getUsedIn();
if (usedIn.isEmpty()) {
cls.setUsedIn(Collections.emptyList());
} else {
Collections.sort(usedIn);
}
}
}
public void collectClassDeps(ClassNode cls) {
RootNode root = cls.root();
Set<ClassNode> depSet = new HashSet<>();
processClass(cls, root, depSet);
@@ -36,10 +64,13 @@ public class DependencyCollector extends AbstractVisitor {
}
depSet.remove(cls);
List<ClassNode> depList = new ArrayList<>(depSet);
depList.sort(Comparator.comparing(c -> c.getClassInfo().getFullName()));
cls.setDependencies(depList);
return false;
if (depSet.isEmpty()) {
cls.setDependencies(Collections.emptyList());
} else {
List<ClassNode> depList = new ArrayList<>(depSet);
Collections.sort(depList);
cls.setDependencies(depList);
}
}
private static void processClass(ClassNode cls, RootNode root, Set<ClassNode> depList) {
@@ -49,12 +80,6 @@ public class DependencyCollector extends AbstractVisitor {
}
for (FieldNode fieldNode : cls.getFields()) {
addDep(root, depList, fieldNode.getType());
// process instructions from field init
FieldInitAttr fieldInitAttr = fieldNode.get(AType.FIELD_INIT);
if (fieldInitAttr != null && fieldInitAttr.getValueType() == FieldInitAttr.InitType.INSN) {
processInsn(root, depList, fieldInitAttr.getInsn());
}
}
// TODO: process annotations and generics
for (MethodNode methodNode : cls.getMethods()) {
@@ -71,41 +96,62 @@ public class DependencyCollector extends AbstractVisitor {
for (ArgType arg : methodNode.getMethodInfo().getArgumentsTypes()) {
addDep(root, depList, arg);
}
for (BlockNode block : methodNode.getBasicBlocks()) {
for (InsnNode insnNode : block.getInstructions()) {
processInsn(root, depList, insnNode);
}
try {
processInstructions(methodNode, depList);
} catch (Exception e) {
methodNode.getCodeReader().visitInstructions(insnData -> {
insnData.decode();
System.out.println(insnData);
});
methodNode.addError("Dependency scan failed", e);
}
}
// TODO: add custom instructions processing
private static void processInsn(RootNode root, Set<ClassNode> depList, InsnNode insnNode) {
RegisterArg result = insnNode.getResult();
if (result != null) {
addDep(root, depList, result.getType());
private static void processInstructions(MethodNode mth, Set<ClassNode> deps) {
ICodeReader codeReader = mth.getCodeReader();
if (codeReader == null) {
return;
}
for (InsnArg arg : insnNode.getArguments()) {
if (arg.isInsnWrap()) {
processInsn(root, depList, ((InsnWrapArg) arg).getWrapInsn());
} else {
addDep(root, depList, arg.getType());
RootNode root = mth.root();
codeReader.visitInstructions(insnData -> {
try {
processInsn(root, insnData, deps);
} catch (Exception e) {
mth.addError("Dependency scan failed at insn: " + insnData, e);
}
}
processCustomInsn(root, depList, insnNode);
});
}
private static void processCustomInsn(RootNode root, Set<ClassNode> depList, InsnNode insn) {
if (insn instanceof IndexInsnNode) {
Object index = ((IndexInsnNode) insn).getIndex();
if (index instanceof FieldInfo) {
addDep(root, depList, ((FieldInfo) index).getDeclClass());
} else if (index instanceof ArgType) {
addDep(root, depList, (ArgType) index);
}
} else if (insn instanceof BaseInvokeNode) {
ClassInfo declClass = ((BaseInvokeNode) insn).getCallMth().getDeclClass();
addDep(root, depList, declClass);
private static void processInsn(RootNode root, InsnData insnData, Set<ClassNode> deps) {
if (insnData.getOpcode() == Opcode.UNKNOWN) {
return;
}
switch (insnData.getIndexType()) {
case TYPE_REF:
insnData.decode();
resolveType(root, deps, insnData.getIndexAsType());
break;
case FIELD_REF:
insnData.decode();
resolveField(root, deps, insnData.getIndexAsField());
break;
case METHOD_REF:
insnData.decode();
resolveMethod(root, deps, insnData.getIndexAsMethod());
break;
}
}
private static void resolveType(RootNode root, Set<ClassNode> deps, String type) {
addDep(root, deps, ArgType.parse(type));
}
private static void resolveMethod(RootNode root, Set<ClassNode> deps, IMethodData method) {
resolveType(root, deps, method.getParentClassType());
}
private static void resolveField(RootNode root, Set<ClassNode> deps, IFieldData field) {
resolveType(root, deps, field.getParentClassType());
}
private static void addDep(RootNode root, Set<ClassNode> depList, ArgType type) {
@@ -3,9 +3,11 @@ package jadx.core.dex.visitors;
import jadx.api.plugins.input.data.AccessFlags;
import jadx.core.dex.attributes.AType;
import jadx.core.dex.info.AccessInfo;
import jadx.core.dex.nodes.ClassNode;
import jadx.core.dex.nodes.ICodeNode;
import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.nodes.RootNode;
import jadx.core.utils.exceptions.JadxException;
@JadxVisitor(
name = "FixAccessModifiers",
@@ -21,12 +23,24 @@ public class FixAccessModifiers extends AbstractVisitor {
this.respectAccessModifiers = root.getArgs().isRespectBytecodeAccModifiers();
}
@Override
public boolean visit(ClassNode cls) throws JadxException {
if (respectAccessModifiers) {
return true;
}
int newVisFlag = fixClassVisibility(cls);
if (newVisFlag != -1) {
changeVisibility(cls, newVisFlag);
}
return true;
}
@Override
public void visit(MethodNode mth) {
if (respectAccessModifiers) {
return;
}
int newVisFlag = fixVisibility(mth);
int newVisFlag = fixMethodVisibility(mth);
if (newVisFlag != -1) {
changeVisibility(mth, newVisFlag);
}
@@ -41,7 +55,35 @@ public class FixAccessModifiers extends AbstractVisitor {
}
}
private static int fixVisibility(MethodNode mth) {
private int fixClassVisibility(ClassNode cls) {
if (cls.getUsedIn().isEmpty()) {
return -1;
}
AccessInfo accessFlags = cls.getAccessFlags();
if (accessFlags.isPrivate()) {
if (!cls.isInner()) {
return AccessFlags.PUBLIC;
}
// check if private inner class is used outside
ClassNode topParentClass = cls.getTopParentClass();
for (ClassNode useCls : cls.getUsedIn()) {
if (useCls.getTopParentClass() != topParentClass) {
return AccessFlags.PUBLIC;
}
}
}
if (accessFlags.isPackagePrivate()) {
String pkg = cls.getPackage();
for (ClassNode useCls : cls.getUsedIn()) {
if (!useCls.getPackage().equals(pkg)) {
return AccessFlags.PUBLIC;
}
}
}
return -1;
}
private static int fixMethodVisibility(MethodNode mth) {
if (mth.isVirtual()) {
// make virtual methods public
return AccessFlags.PUBLIC;
@@ -18,13 +18,16 @@ import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.empty;
public class JadxVisitorsOrderTest {
private static final Logger LOG = LoggerFactory.getLogger(JadxVisitorsOrderTest.class);
@Test
public void testOrder() {
List<IDexTreeVisitor> passes = Jadx.getPassesList(new JadxArgs());
checkPassList(Jadx.getPassesList(new JadxArgs()));
checkPassList(Jadx.getPreDecompilePassesList());
checkPassList(Jadx.getFallbackPassesList());
}
private void checkPassList(List<IDexTreeVisitor> passes) {
List<String> errors = check(passes);
for (String str : errors) {
LOG.error(str);
@@ -55,7 +58,8 @@ public class JadxVisitorsOrderTest {
errors.add("Visitor name conflict: " + passName + ", class: " + passClass.getName());
}
for (Class<? extends IDexTreeVisitor> cls : info.runBefore()) {
if (classList.indexOf(cls) < i) {
int beforeIndex = classList.indexOf(cls);
if (beforeIndex != -1 && beforeIndex < i) {
errors.add("Pass " + passName + " must be before " + cls.getSimpleName());
}
}