diff --git a/jadx-core/src/main/java/jadx/api/JadxDecompiler.java b/jadx-core/src/main/java/jadx/api/JadxDecompiler.java index 215f0106b..4ca555914 100644 --- a/jadx-core/src/main/java/jadx/api/JadxDecompiler.java +++ b/jadx-core/src/main/java/jadx/api/JadxDecompiler.java @@ -2,6 +2,7 @@ package jadx.api; import jadx.core.Jadx; import jadx.core.ProcessClass; +import jadx.core.codegen.CodeGen; import jadx.core.codegen.CodeWriter; import jadx.core.deobf.DefaultDeobfuscator; import jadx.core.deobf.Deobfuscator; @@ -57,6 +58,8 @@ public final class JadxDecompiler { private RootNode root; private List passes; + private CodeGen codeGen; + private List classes; private List resources; @@ -83,6 +86,7 @@ public final class JadxDecompiler { outDir = new DefaultJadxArgs().getOutDir(); } this.passes = Jadx.getPassesList(args, outDir); + this.codeGen = new CodeGen(args); } void reset() { @@ -305,7 +309,7 @@ public final class JadxDecompiler { } void processClass(ClassNode cls) { - ProcessClass.process(cls, passes); + ProcessClass.process(cls, passes, codeGen); } RootNode getRoot() { @@ -331,6 +335,10 @@ public final class JadxDecompiler { return null; } + public IJadxArgs getArgs() { + return args; + } + @Override public String toString() { return "jadx decompiler " + getVersion(); diff --git a/jadx-core/src/main/java/jadx/core/Jadx.java b/jadx-core/src/main/java/jadx/core/Jadx.java index e79e1de9a..b083b25d7 100644 --- a/jadx-core/src/main/java/jadx/core/Jadx.java +++ b/jadx-core/src/main/java/jadx/core/Jadx.java @@ -1,11 +1,11 @@ package jadx.core; import jadx.api.IJadxArgs; -import jadx.core.codegen.CodeGen; import jadx.core.dex.visitors.ClassModifier; import jadx.core.dex.visitors.CodeShrinker; import jadx.core.dex.visitors.ConstInlineVisitor; import jadx.core.dex.visitors.DebugInfoVisitor; +import jadx.core.dex.visitors.DependencyCollector; import jadx.core.dex.visitors.DotGraphVisitor; import jadx.core.dex.visitors.EnumVisitor; import jadx.core.dex.visitors.FallbackModeVisitor; @@ -104,8 +104,9 @@ public class Jadx { passes.add(new PrepareForCodeGen()); passes.add(new LoopRegionVisitor()); passes.add(new ProcessVariables()); + + passes.add(new DependencyCollector()); } - passes.add(new CodeGen(args)); return passes; } diff --git a/jadx-core/src/main/java/jadx/core/ProcessClass.java b/jadx-core/src/main/java/jadx/core/ProcessClass.java index 2a47a87ef..08f686f46 100644 --- a/jadx-core/src/main/java/jadx/core/ProcessClass.java +++ b/jadx-core/src/main/java/jadx/core/ProcessClass.java @@ -1,30 +1,55 @@ package jadx.core; +import jadx.core.codegen.CodeGen; import jadx.core.dex.nodes.ClassNode; import jadx.core.dex.visitors.DepthTraversal; import jadx.core.dex.visitors.IDexTreeVisitor; +import jadx.core.utils.ErrorsCounter; import java.util.List; +import org.jetbrains.annotations.Nullable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import static jadx.core.dex.nodes.ProcessState.GENERATED; +import static jadx.core.dex.nodes.ProcessState.NOT_LOADED; +import static jadx.core.dex.nodes.ProcessState.PROCESSED; +import static jadx.core.dex.nodes.ProcessState.STARTED; +import static jadx.core.dex.nodes.ProcessState.UNLOADED; + public final class ProcessClass { private static final Logger LOG = LoggerFactory.getLogger(ProcessClass.class); private ProcessClass() { } - public static void process(ClassNode cls, List passes) { - try { - cls.load(); - for (IDexTreeVisitor visitor : passes) { - DepthTraversal.visit(visitor, cls); + public static void process(ClassNode cls, List passes, @Nullable CodeGen codeGen) { + synchronized (cls) { + try { + if (cls.getState() == NOT_LOADED) { + cls.load(); + cls.setState(STARTED); + for (IDexTreeVisitor visitor : passes) { + DepthTraversal.visit(visitor, cls); + } + for (ClassNode clsNode : cls.getDependencies()) { + process(clsNode, passes, null); + } + cls.setState(PROCESSED); + } + if (cls.getState() == PROCESSED && codeGen != null) { + codeGen.visit(cls); + cls.setState(GENERATED); + } + } catch (Exception e) { + ErrorsCounter.classError(cls, e.getClass().getSimpleName(), e); + } finally { + if (cls.getState() == GENERATED) { + cls.unload(); + cls.setState(UNLOADED); + } } - } catch (Exception e) { - LOG.error("Class process exception: {}", cls, e); - } finally { - cls.unload(); } } } diff --git a/jadx-core/src/main/java/jadx/core/dex/nodes/ClassNode.java b/jadx-core/src/main/java/jadx/core/dex/nodes/ClassNode.java index 030196c56..464cbc30d 100644 --- a/jadx-core/src/main/java/jadx/core/dex/nodes/ClassNode.java +++ b/jadx-core/src/main/java/jadx/core/dex/nodes/ClassNode.java @@ -24,9 +24,11 @@ import jadx.core.utils.exceptions.JadxRuntimeException; import java.util.ArrayList; import java.util.Collections; +import java.util.HashSet; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import java.util.Set; import org.jetbrains.annotations.TestOnly; import org.slf4j.Logger; @@ -58,6 +60,9 @@ public class ClassNode extends LineAttrNode implements ILoadable { // store parent for inner classes or 'this' otherwise private ClassNode parentClass; + private ProcessState state = ProcessState.NOT_LOADED; + private final Set dependencies = new HashSet(); + public ClassNode(DexNode dex, ClassDef cls) throws DecodeException { this.dex = dex; this.clsInfo = ClassInfo.fromDex(dex, cls.getTypeIndex()); @@ -452,6 +457,18 @@ public class ClassNode extends LineAttrNode implements ILoadable { return code; } + public ProcessState getState() { + return state; + } + + public void setState(ProcessState state) { + this.state = state; + } + + public Set getDependencies() { + return dependencies; + } + @Override public String toString() { return getFullName(); diff --git a/jadx-core/src/main/java/jadx/core/dex/nodes/ProcessState.java b/jadx-core/src/main/java/jadx/core/dex/nodes/ProcessState.java new file mode 100644 index 000000000..18b4d4d1d --- /dev/null +++ b/jadx-core/src/main/java/jadx/core/dex/nodes/ProcessState.java @@ -0,0 +1,9 @@ +package jadx.core.dex.nodes; + +public enum ProcessState { + NOT_LOADED, + STARTED, + PROCESSED, + GENERATED, + UNLOADED +} diff --git a/jadx-core/src/main/java/jadx/core/dex/visitors/DependencyCollector.java b/jadx-core/src/main/java/jadx/core/dex/visitors/DependencyCollector.java new file mode 100644 index 000000000..eca74014a --- /dev/null +++ b/jadx-core/src/main/java/jadx/core/dex/visitors/DependencyCollector.java @@ -0,0 +1,107 @@ +package jadx.core.dex.visitors; + +import jadx.core.dex.info.ClassInfo; +import jadx.core.dex.instructions.args.ArgType; +import jadx.core.dex.instructions.args.InsnArg; +import jadx.core.dex.instructions.args.InsnWrapArg; +import jadx.core.dex.instructions.args.RegisterArg; +import jadx.core.dex.nodes.BlockNode; +import jadx.core.dex.nodes.ClassNode; +import jadx.core.dex.nodes.DexNode; +import jadx.core.dex.nodes.FieldNode; +import jadx.core.dex.nodes.InsnNode; +import jadx.core.dex.nodes.MethodNode; +import jadx.core.utils.exceptions.JadxException; + +import java.util.Set; + +public class DependencyCollector extends AbstractVisitor { + + @Override + public boolean visit(ClassNode cls) throws JadxException { + DexNode dex = cls.dex(); + Set depList = cls.getDependencies(); + processClass(cls, dex, depList); + for (ClassNode inner : cls.getInnerClasses()) { + processClass(inner, dex, depList); + } + depList.remove(cls); + return false; + } + + private static void processClass(ClassNode cls, DexNode dex, Set depList) { + addDep(dex, depList, cls.getSuperClass()); + for (ClassInfo clsInfo : cls.getInterfaces()) { + addDep(dex, depList, clsInfo); + } + for (FieldNode fieldNode : cls.getFields()) { + addDep(dex, depList, fieldNode.getType()); + } + // TODO: process annotations and generics + for (MethodNode methodNode : cls.getMethods()) { + if (methodNode.isNoCode()) { + continue; + } + processMethod(dex, depList, methodNode); + } + } + + private static void processMethod(DexNode dex, Set depList, MethodNode methodNode) { + addDep(dex, depList, methodNode.getParentClass()); + addDep(dex, depList, methodNode.getReturnType()); + for (ArgType arg : methodNode.getMethodInfo().getArgumentsTypes()) { + addDep(dex, depList, arg); + } + for (BlockNode block : methodNode.getBasicBlocks()) { + for (InsnNode insnNode : block.getInstructions()) { + processInsn(dex, depList, insnNode); + } + } + } + + // TODO: add custom instructions processing + private static void processInsn(DexNode dex, Set depList, InsnNode insnNode) { + RegisterArg result = insnNode.getResult(); + if (result != null) { + addDep(dex, depList, result.getType()); + } + for (InsnArg arg : insnNode.getArguments()) { + if (arg.isInsnWrap()) { + processInsn(dex, depList, ((InsnWrapArg) arg).getWrapInsn()); + } else { + addDep(dex, depList, arg.getType()); + } + } + } + + private static void addDep(DexNode dex, Set depList, ArgType type) { + if (type != null) { + if (type.isObject()) { + addDep(dex, depList, ClassInfo.fromName(type.getObject())); + ArgType[] genericTypes = type.getGenericTypes(); + if (type.isGeneric() && genericTypes != null) { + for (ArgType argType : genericTypes) { + addDep(dex, depList, argType); + } + } + } else if (type.isArray()) { + addDep(dex, depList, type.getArrayRootElement()); + } + } + } + + private static void addDep(DexNode dex, Set depList, ClassInfo clsInfo) { + if (clsInfo != null) { + ClassNode node = dex.resolveClass(clsInfo); + if (node != null) { + depList.add(node); + } + } + } + + private static void addDep(DexNode dex, Set depList, ClassNode clsNode) { + if (clsNode != null) { + depList.add(clsNode); + } + } +} diff --git a/jadx-core/src/test/java/jadx/tests/api/IntegrationTest.java b/jadx-core/src/test/java/jadx/tests/api/IntegrationTest.java index f490ce08c..20e344bd3 100644 --- a/jadx-core/src/test/java/jadx/tests/api/IntegrationTest.java +++ b/jadx-core/src/test/java/jadx/tests/api/IntegrationTest.java @@ -4,6 +4,8 @@ import jadx.api.DefaultJadxArgs; import jadx.api.JadxDecompiler; import jadx.api.JadxInternalAccess; import jadx.core.Jadx; +import jadx.core.ProcessClass; +import jadx.core.codegen.CodeGen; import jadx.core.dex.attributes.AFlag; import jadx.core.dex.attributes.AType; import jadx.core.dex.nodes.ClassNode; @@ -11,6 +13,7 @@ import jadx.core.dex.nodes.MethodNode; import jadx.core.dex.nodes.RootNode; import jadx.core.dex.visitors.DepthTraversal; import jadx.core.dex.visitors.IDexTreeVisitor; +import jadx.core.utils.exceptions.CodegenException; import jadx.core.utils.exceptions.JadxException; import jadx.core.utils.files.FileUtils; import jadx.tests.api.compiler.DynamicCompiler; @@ -51,6 +54,7 @@ public abstract class IntegrationTest extends TestUtils { protected boolean isFallback = false; protected boolean deleteTmpFiles = true; protected boolean withDebugInfo = true; + protected boolean unloadCls = true; protected Map resMap = Collections.emptyMap(); @@ -64,16 +68,18 @@ public abstract class IntegrationTest extends TestUtils { File jar = getJarForClass(clazz); return getClassNodeFromFile(jar, clazz.getName()); } catch (Exception e) { + e.printStackTrace(); fail(e.getMessage()); } return null; } public ClassNode getClassNodeFromFile(File file, String clsName) { - JadxDecompiler d = new JadxDecompiler(); + JadxDecompiler d = new JadxDecompiler(getArgs()); try { d.loadFile(file); } catch (JadxException e) { + e.printStackTrace(); fail(e.getMessage()); } RootNode root = JadxInternalAccess.getRoot(d); @@ -83,11 +89,11 @@ public abstract class IntegrationTest extends TestUtils { assertNotNull("Class not found: " + clsName, cls); assertEquals(cls.getFullName(), clsName); - cls.load(); - for (IDexTreeVisitor visitor : getPasses()) { - DepthTraversal.visit(visitor, cls); + if (unloadCls) { + decompile(d, cls); + } else { + decompileWithoutUnload(d, cls); } - // don't unload class System.out.println("-----------------------------------------------------------"); System.out.println(cls.getCode()); @@ -99,6 +105,26 @@ public abstract class IntegrationTest extends TestUtils { return cls; } + private void decompile(JadxDecompiler jadx, ClassNode cls) { + List passes = Jadx.getPassesList(jadx.getArgs(), new File(outDir)); + ProcessClass.process(cls, passes, new CodeGen(jadx.getArgs())); + } + + private void decompileWithoutUnload(JadxDecompiler d, ClassNode cls) { + cls.load(); + List passes = Jadx.getPassesList(d.getArgs(), new File(outDir)); + for (IDexTreeVisitor visitor : passes) { + DepthTraversal.visit(visitor, cls); + } + try { + new CodeGen(d.getArgs()).visit(cls); + } catch (CodegenException e) { + e.printStackTrace(); + fail(e.getMessage()); + } + // don't unload class + } + private static void checkCode(ClassNode cls) { assertTrue("Inconsistent cls: " + cls, !cls.contains(AFlag.INCONSISTENT_CODE) && !cls.contains(AType.JADX_ERROR)); @@ -109,8 +135,8 @@ public abstract class IntegrationTest extends TestUtils { assertThat(cls.getCode().toString(), not(containsString("inconsistent"))); } - protected List getPasses() { - return Jadx.getPassesList(new DefaultJadxArgs() { + private DefaultJadxArgs getArgs() { + return new DefaultJadxArgs() { @Override public boolean isCFGOutput() { return outputCFG; @@ -140,7 +166,7 @@ public abstract class IntegrationTest extends TestUtils { public boolean isSkipResources() { return true; } - }, new File(outDir)); + }; } private void runAutoCheck(String clsName) { @@ -363,6 +389,10 @@ public abstract class IntegrationTest extends TestUtils { this.compile = false; } + protected void dontUnloadClass() { + this.unloadCls = false; + } + // Use only for debug purpose @Deprecated protected void setOutputCFG() { diff --git a/jadx-core/src/test/java/jadx/tests/integration/TestDuplicateCast.java b/jadx-core/src/test/java/jadx/tests/integration/TestDuplicateCast.java index ef9594c0b..3b6fa458f 100644 --- a/jadx-core/src/test/java/jadx/tests/integration/TestDuplicateCast.java +++ b/jadx-core/src/test/java/jadx/tests/integration/TestDuplicateCast.java @@ -31,6 +31,7 @@ public class TestDuplicateCast extends IntegrationTest { @Test public void test() { + dontUnloadClass(); ClassNode cls = getClassNode(TestCls.class); MethodNode mth = getMethod(cls, "method");