diff --git a/jadx-core/src/main/java/jadx/core/dex/visitors/RenameVisitor.java b/jadx-core/src/main/java/jadx/core/dex/visitors/RenameVisitor.java index 8df1daea6..a713f9d68 100644 --- a/jadx-core/src/main/java/jadx/core/dex/visitors/RenameVisitor.java +++ b/jadx-core/src/main/java/jadx/core/dex/visitors/RenameVisitor.java @@ -5,6 +5,8 @@ import java.util.HashSet; import java.util.List; import java.util.Set; +import org.jetbrains.annotations.Nullable; + import jadx.api.JadxArgs; import jadx.core.Consts; import jadx.core.deobf.Deobfuscator; @@ -65,6 +67,9 @@ public class RenameVisitor extends AbstractVisitor { } } } + if (args.isRenameValid()) { + checkFieldsCollisionWithRootPackage(classes); + } } private void checkClassName(ClassNode cls, JadxArgs args) { @@ -134,4 +139,42 @@ public class RenameVisitor extends AbstractVisitor { } } } + + private void checkFieldsCollisionWithRootPackage(List classes) { + Set rootPkgs = collectRootPkgs(classes); + for (ClassNode cls : classes) { + for (FieldNode field : cls.getFields()) { + if (rootPkgs.contains(field.getAlias())) { + deobfuscator.forceRenameField(field); + } + } + } + } + + private static Set collectRootPkgs(List classes) { + Set fullPkgs = new HashSet<>(); + for (ClassNode cls : classes) { + fullPkgs.add(cls.getAlias().getPackage()); + } + Set rootPkgs = new HashSet<>(); + for (String pkg : fullPkgs) { + String rootPkg = getRootPkg(pkg); + if (rootPkg != null) { + rootPkgs.add(rootPkg); + } + } + return rootPkgs; + } + + @Nullable + private static String getRootPkg(String pkg) { + if (pkg.isEmpty()) { + return null; + } + int dotPos = pkg.indexOf('.'); + if (dotPos < 0) { + return pkg; + } + return pkg.substring(0, dotPos); + } } diff --git a/jadx-core/src/test/java/jadx/tests/api/IntegrationTest.java b/jadx-core/src/test/java/jadx/tests/api/IntegrationTest.java index 21e43b601..95348db69 100644 --- a/jadx-core/src/test/java/jadx/tests/api/IntegrationTest.java +++ b/jadx-core/src/test/java/jadx/tests/api/IntegrationTest.java @@ -124,10 +124,20 @@ public abstract class IntegrationTest extends TestUtils { assertThat("Class not found: " + clsName, cls, notNullValue()); assertThat(clsName, is(cls.getClassInfo().getFullName())); - decompileAndCheckCls(d, cls); + decompileAndCheck(d, Collections.singletonList(cls)); return cls; } + public ClassNode searchCls(List list, String fullClsName) { + for (ClassNode cls : list) { + if (cls.getClassInfo().getFullName().equals(fullClsName)) { + return cls; + } + } + fail("Class not found by name " + fullClsName + " in list: " + list); + return null; + } + protected JadxDecompiler loadFiles(List inputFiles) { JadxDecompiler d = null; try { @@ -137,26 +147,29 @@ public abstract class IntegrationTest extends TestUtils { } catch (Exception e) { e.printStackTrace(); fail(e.getMessage()); + return null; } RootNode root = JadxInternalAccess.getRoot(d); insertResources(root); return d; } - protected void decompileAndCheckCls(JadxDecompiler d, ClassNode cls) { + protected void decompileAndCheck(JadxDecompiler d, List clsList) { if (unloadCls) { - decompile(d, cls); + clsList.forEach(cls -> decompile(d, cls)); } else { - decompileWithoutUnload(d, cls); + clsList.forEach(cls -> decompileWithoutUnload(d, cls)); } - System.out.println("-----------------------------------------------------------"); - System.out.println(cls.getCode()); + for (ClassNode cls : clsList) { + System.out.println("-----------------------------------------------------------"); + System.out.println(cls.getCode()); + } System.out.println("-----------------------------------------------------------"); - checkCode(cls); - compile(cls); - runAutoCheck(cls.getClassInfo().getFullName()); + clsList.forEach(IntegrationTest::checkCode); + compile(clsList); + clsList.forEach(this::runAutoCheck); } private void insertResources(RootNode root) { @@ -221,7 +234,8 @@ public abstract class IntegrationTest extends TestUtils { return false; } - private void runAutoCheck(String clsName) { + private void runAutoCheck(ClassNode cls) { + String clsName = cls.getClassInfo().getFullName(); try { // run 'check' method from original class Class origCls; @@ -252,7 +266,7 @@ public abstract class IntegrationTest extends TestUtils { // run 'check' method from decompiled class if (compile) { try { - limitExecTime(() -> invoke("check")); + limitExecTime(() -> invoke(cls, "check")); } catch (Exception e) { rethrow("Decompiled check failed", e); } @@ -306,11 +320,15 @@ public abstract class IntegrationTest extends TestUtils { } void compile(ClassNode cls) { + compile(Collections.singletonList(cls)); + } + + void compile(List clsList) { if (!compile) { return; } try { - dynamicCompiler = new DynamicCompiler(cls); + dynamicCompiler = new DynamicCompiler(clsList); boolean result = dynamicCompiler.compile(); assertTrue(result, "Compilation failed"); System.out.println("Compilation: PASSED"); @@ -319,30 +337,13 @@ public abstract class IntegrationTest extends TestUtils { } } - public Object invoke(String method) throws Exception { - return invoke(method, new Class[0]); + public Object invoke(ClassNode cls, String method) throws Exception { + return invoke(cls, method, new Class[0]); } - public Object invoke(String method, Class[] types, Object... args) throws Exception { - Method mth = getReflectMethod(method, types); - return invoke(mth, args); - } - - public Method getReflectMethod(String method, Class... types) { + public Object invoke(ClassNode cls, String methodName, Class[] types, Object... args) throws Exception { assertNotNull(dynamicCompiler, "dynamicCompiler not ready"); - try { - return dynamicCompiler.getMethod(method, types); - } catch (Exception e) { - e.printStackTrace(); - fail(e.getMessage()); - } - return null; - } - - public Object invoke(Method mth, Object... args) throws Exception { - assertNotNull(dynamicCompiler, "dynamicCompiler not ready"); - assertNotNull(mth, "unknown method"); - return dynamicCompiler.invoke(mth, args); + return dynamicCompiler.invoke(cls, methodName, types, args); } public File getJarForClass(Class cls) throws IOException { diff --git a/jadx-core/src/test/java/jadx/tests/api/SmaliTest.java b/jadx-core/src/test/java/jadx/tests/api/SmaliTest.java index 79875d7c5..bd51d03a2 100644 --- a/jadx-core/src/test/java/jadx/tests/api/SmaliTest.java +++ b/jadx-core/src/test/java/jadx/tests/api/SmaliTest.java @@ -56,9 +56,7 @@ public abstract class SmaliTest extends IntegrationTest { JadxDecompiler d = loadFiles(Collections.singletonList(outDex)); RootNode root = JadxInternalAccess.getRoot(d); List classes = root.getClasses(false); - for (ClassNode cls : classes) { - decompileAndCheckCls(d, cls); - } + decompileAndCheck(d, classes); return classes; } diff --git a/jadx-core/src/test/java/jadx/tests/api/compiler/DynamicCompiler.java b/jadx-core/src/test/java/jadx/tests/api/compiler/DynamicCompiler.java index 674d29e32..f73ab7b07 100644 --- a/jadx-core/src/test/java/jadx/tests/api/compiler/DynamicCompiler.java +++ b/jadx-core/src/test/java/jadx/tests/api/compiler/DynamicCompiler.java @@ -2,6 +2,7 @@ package jadx.tests.api.compiler; import java.lang.reflect.Method; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import javax.tools.JavaCompiler; @@ -9,31 +10,28 @@ import javax.tools.JavaFileManager; import javax.tools.JavaFileObject; import javax.tools.ToolProvider; +import org.jetbrains.annotations.NotNull; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import jadx.core.dex.nodes.ClassNode; import static javax.tools.JavaCompiler.CompilationTask; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.fail; public class DynamicCompiler { private static final Logger LOG = LoggerFactory.getLogger(DynamicCompiler.class); - private final ClassNode clsNode; - + private final List clsNodeList; private JavaFileManager fileManager; - private Object instance; - - public DynamicCompiler(ClassNode clsNode) { - this.clsNode = clsNode; + public DynamicCompiler(List clsNodeList) { + this.clsNodeList = clsNodeList; } - public boolean compile() throws Exception { - String fullName = clsNode.getFullName(); - String code = clsNode.getCode().toString(); - + public boolean compile() { JavaCompiler compiler = ToolProvider.getSystemJavaCompiler(); if (compiler == null) { LOG.error("Can not find compiler, please use JDK instead"); @@ -41,8 +39,10 @@ public class DynamicCompiler { } fileManager = new ClassFileManager(compiler.getStandardFileManager(null, null, null)); - List jFiles = new ArrayList<>(1); - jFiles.add(new CharSequenceJavaFileObject(fullName, code)); + List jFiles = new ArrayList<>(clsNodeList.size()); + for (ClassNode clsNode : clsNodeList) { + jFiles.add(new CharSequenceJavaFileObject(clsNode.getFullName(), clsNode.getCode().toString())); + } CompilationTask compilerTask = compiler.getTask(null, fileManager, null, null, null, jFiles); return Boolean.TRUE.equals(compilerTask.call()); @@ -52,27 +52,29 @@ public class DynamicCompiler { return fileManager.getClassLoader(null); } - private void makeInstance() throws Exception { - String fullName = clsNode.getFullName(); - instance = getClassLoader().loadClass(fullName).getConstructor().newInstance(); + public Object makeInstance(ClassNode cls) throws Exception { + String fullName = cls.getFullName(); + return getClassLoader().loadClass(fullName).getConstructor().newInstance(); } - private Object getInstance() throws Exception { - if (instance == null) { - makeInstance(); - } - return instance; - } - - public Method getMethod(String method, Class[] types) throws Exception { + @NotNull + public Method getMethod(Object inst, String methodName, Class[] types) throws Exception { for (Class type : types) { checkType(type); } - return getInstance().getClass().getMethod(method, types); + return inst.getClass().getMethod(methodName, types); } - public Object invoke(Method mth, Object... args) throws Exception { - return mth.invoke(getInstance(), args); + public Object invoke(ClassNode cls, String methodName, Class[] types, Object[] args) { + try { + Object inst = makeInstance(cls); + Method reflMth = getMethod(inst, methodName, types); + assertNotNull(reflMth, "Failed to get method " + methodName + '(' + Arrays.toString(types) + ')'); + return reflMth.invoke(inst, args); + } catch (Exception e) { + fail(e.getMessage(), e); + return null; + } } private Class checkType(Class type) throws ClassNotFoundException { diff --git a/jadx-core/src/test/java/jadx/tests/integration/loops/TestBreakWithLabel.java b/jadx-core/src/test/java/jadx/tests/integration/loops/TestBreakWithLabel.java index 0747a56cd..b30b5a74f 100644 --- a/jadx-core/src/test/java/jadx/tests/integration/loops/TestBreakWithLabel.java +++ b/jadx-core/src/test/java/jadx/tests/integration/loops/TestBreakWithLabel.java @@ -1,7 +1,5 @@ package jadx.tests.integration.loops; -import java.lang.reflect.Method; - import org.junit.jupiter.api.Test; import jadx.core.dex.nodes.ClassNode; @@ -29,6 +27,12 @@ public class TestBreakWithLabel extends IntegrationTest { System.out.println("found: " + found); return found; } + + public void check() { + int[][] testArray = { { 1, 2 }, { 3, 4 } }; + assertTrue(test(testArray, 3)); + assertFalse(test(testArray, 5)); + } } @Test @@ -38,10 +42,5 @@ public class TestBreakWithLabel extends IntegrationTest { assertThat(code, containsOne("loop0:")); assertThat(code, containsOne("break loop0;")); - - Method test = getReflectMethod("test", int[][].class, int.class); - int[][] testArray = { { 1, 2 }, { 3, 4 } }; - assertTrue((Boolean) invoke(test, testArray, 3)); - assertFalse((Boolean) invoke(test, testArray, 5)); } } diff --git a/jadx-core/src/test/java/jadx/tests/integration/names/TestFieldCollideWithPackage.java b/jadx-core/src/test/java/jadx/tests/integration/names/TestFieldCollideWithPackage.java new file mode 100644 index 000000000..7b8dc7168 --- /dev/null +++ b/jadx-core/src/test/java/jadx/tests/integration/names/TestFieldCollideWithPackage.java @@ -0,0 +1,62 @@ +package jadx.tests.integration.names; + +import java.util.List; + +import org.junit.jupiter.api.Test; + +import jadx.core.dex.nodes.ClassNode; +import jadx.tests.api.SmaliTest; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.not; + +public class TestFieldCollideWithPackage extends SmaliTest { + //@formatter:off + /* + ----------------------------------------------------------- + package first; + + public class A { + public A first; + public second.A second; + + public String test() { + return second.A.call(); // compiler treat 'second' as field name + } + } + ----------------------------------------------------------- + package second; + + public class A { + public static String call() { + return null; + } + } + ----------------------------------------------------------- + */ + //@formatter:on + + @Test + public void test() { + List clsList = loadFromSmaliFiles(); + ClassNode firstA = searchCls(clsList, "first.A"); + String code = firstA.getCode().toString(); + + assertThat(code, containsString("second.A")); + // expect field to be renamed + assertThat(code, not(containsString("public second.A second;"))); + } + + @Test + public void testWithoutImports() { + getArgs().setUseImports(false); + loadFromSmaliFiles(); + } + + @Test + public void testWithDeobfuscation() { + enableDeobfuscation(); + loadFromSmaliFiles(); + } +} diff --git a/jadx-core/src/test/smali/names/TestFieldCollideWithPackage/1.smali b/jadx-core/src/test/smali/names/TestFieldCollideWithPackage/1.smali new file mode 100644 index 000000000..1e8e304fc --- /dev/null +++ b/jadx-core/src/test/smali/names/TestFieldCollideWithPackage/1.smali @@ -0,0 +1,15 @@ +.class public Lfirst/A; +.super Ljava/lang/Object; + +.field public first:Lfirst/A; +.field public second:Lsecond/A; + +.method public test()Ljava/lang/String; + .registers 2 + + invoke-static {}, Lsecond/A;->call()Ljava/lang/String; + + move-result-object v0 + + return-object v0 +.end method diff --git a/jadx-core/src/test/smali/names/TestFieldCollideWithPackage/2.smali b/jadx-core/src/test/smali/names/TestFieldCollideWithPackage/2.smali new file mode 100644 index 000000000..59046fb39 --- /dev/null +++ b/jadx-core/src/test/smali/names/TestFieldCollideWithPackage/2.smali @@ -0,0 +1,10 @@ +.class public Lsecond/A; +.super Ljava/lang/Object; + +.method static public call()Ljava/lang/String; + .registers 1 + + const v0, 0 + + return-object v0 +.end method