From 8969d11a22732bb3a8fa1b0e13143886fda2b12b Mon Sep 17 00:00:00 2001 From: Skylot Date: Tue, 14 Sep 2021 19:17:57 +0100 Subject: [PATCH] fix: restore fields order on init code move (#678) --- .../java/jadx/core/dex/nodes/InsnNode.java | 23 ++ .../core/dex/visitors/ExtractFieldInit.java | 331 ++++++++++-------- .../main/java/jadx/core/utils/BlockUtils.java | 24 ++ .../java/jadx/tests/api/utils/TestUtils.java | 3 + .../integration/others/TestFieldInit.java | 46 --- .../others/TestFieldInitInTryCatch.java | 4 +- .../others/TestFieldInitOrder.java | 34 ++ .../others/TestFieldInitOrderStatic.java | 35 ++ 8 files changed, 302 insertions(+), 198 deletions(-) delete mode 100644 jadx-core/src/test/java/jadx/tests/integration/others/TestFieldInit.java create mode 100644 jadx-core/src/test/java/jadx/tests/integration/others/TestFieldInitOrder.java create mode 100644 jadx-core/src/test/java/jadx/tests/integration/others/TestFieldInitOrderStatic.java diff --git a/jadx-core/src/main/java/jadx/core/dex/nodes/InsnNode.java b/jadx-core/src/main/java/jadx/core/dex/nodes/InsnNode.java index 711fd3eaf..f546ea2f6 100644 --- a/jadx-core/src/main/java/jadx/core/dex/nodes/InsnNode.java +++ b/jadx-core/src/main/java/jadx/core/dex/nodes/InsnNode.java @@ -6,6 +6,7 @@ import java.util.Collections; import java.util.List; import java.util.Objects; import java.util.function.Consumer; +import java.util.function.Function; import org.jetbrains.annotations.Nullable; @@ -304,6 +305,28 @@ public class InsnNode extends LineAttrNode { } } + /** + * Visit this instruction and all inner (wrapped) instructions + * To terminate visiting return non-null value + */ + @Nullable + public R visitInsns(Function visitor) { + R result = visitor.apply(this); + if (result != null) { + return result; + } + for (InsnArg arg : this.getArguments()) { + if (arg.isInsnWrap()) { + InsnNode innerInsn = ((InsnWrapArg) arg).getWrapInsn(); + R res = innerInsn.visitInsns(visitor); + if (res != null) { + return res; + } + } + } + return null; + } + /** * 'Soft' equals, don't compare arguments, only instruction specific parameters. */ diff --git a/jadx-core/src/main/java/jadx/core/dex/visitors/ExtractFieldInit.java b/jadx-core/src/main/java/jadx/core/dex/visitors/ExtractFieldInit.java index 1f22362ab..506f5e5eb 100644 --- a/jadx-core/src/main/java/jadx/core/dex/visitors/ExtractFieldInit.java +++ b/jadx-core/src/main/java/jadx/core/dex/visitors/ExtractFieldInit.java @@ -4,7 +4,9 @@ import java.util.ArrayList; import java.util.Collections; import java.util.HashSet; import java.util.List; +import java.util.Objects; import java.util.Set; +import java.util.stream.Collectors; import jadx.api.plugins.input.data.attributes.JadxAttrType; import jadx.core.dex.attributes.AFlag; @@ -25,6 +27,7 @@ import jadx.core.dex.nodes.MethodNode; import jadx.core.dex.visitors.shrink.CodeShrinkVisitor; import jadx.core.utils.BlockUtils; import jadx.core.utils.InsnRemover; +import jadx.core.utils.Utils; import jadx.core.utils.exceptions.JadxException; @JadxVisitor( @@ -40,193 +43,191 @@ public class ExtractFieldInit extends AbstractVisitor { for (ClassNode inner : cls.getInnerClasses()) { visit(inner); } - checkStaticFieldsInit(cls); moveStaticFieldsInit(cls); moveCommonFieldsInit(cls); return false; } - private static void checkStaticFieldsInit(ClassNode cls) { - MethodNode clinit = cls.getClassInitMth(); - if (clinit == null - || !clinit.getAccessFlags().isStatic() - || clinit.isNoCode() - || clinit.getBasicBlocks() == null) { - return; - } + private static final class FieldInitInfo { + final FieldNode fieldNode; + final IndexInsnNode putInsn; + final boolean singlePath; - for (BlockNode block : clinit.getBasicBlocks()) { - for (InsnNode insn : block.getInstructions()) { - if (insn.getType() == InsnType.SPUT) { - processStaticFieldAssign(cls, (IndexInsnNode) insn); - } - } + public FieldInitInfo(FieldNode fieldNode, IndexInsnNode putInsn, boolean singlePath) { + this.fieldNode = fieldNode; + this.putInsn = putInsn; + this.singlePath = singlePath; } } - /** - * Remove a final field in place initialization if it an assign found in class init method - */ - private static void processStaticFieldAssign(ClassNode cls, IndexInsnNode insn) { - FieldInfo field = (FieldInfo) insn.getIndex(); - if (field.getDeclClass().equals(cls.getClassInfo())) { - FieldNode fn = cls.searchField(field); - if (fn != null && fn.getAccessFlags().isFinal()) { - fn.remove(JadxAttrType.CONSTANT_VALUE); - } + private static final class ConstructorInitInfo { + final MethodNode constructorMth; + final List fieldInits; + + private ConstructorInitInfo(MethodNode constructorMth, List fieldInits) { + this.constructorMth = constructorMth; + this.fieldInits = fieldInits; } } private static void moveStaticFieldsInit(ClassNode cls) { MethodNode classInitMth = cls.getClassInitMth(); - if (classInitMth == null) { + if (classInitMth == null + || !classInitMth.getAccessFlags().isStatic() + || classInitMth.isNoCode() + || classInitMth.getBasicBlocks() == null) { return; } - while (processFields(cls, classInitMth)) { + while (processStaticFields(cls, classInitMth)) { // sometimes instructions moved to field init prevent from vars inline -> inline and try again CodeShrinkVisitor.shrinkMethod(classInitMth); } } - private static boolean processFields(ClassNode cls, MethodNode classInitMth) { - boolean changed = false; - for (FieldNode field : cls.getFields()) { - if (field.contains(AFlag.DONT_GENERATE) || field.contains(AType.FIELD_INIT_INSN)) { - continue; - } - if (field.getAccessFlags().isStatic()) { - List initInsns = getFieldAssigns(classInitMth, field, InsnType.SPUT); - if (initInsns.size() == 1) { - InsnNode insn = initInsns.get(0); - if (checkInsn(cls, insn)) { - InsnArg arg = insn.getArg(0); - if (arg instanceof InsnWrapArg) { - ((InsnWrapArg) arg).getWrapInsn().add(AFlag.DECLARE_VAR); - } - InsnRemover.remove(classInitMth, insn); - addFieldInitAttr(classInitMth, field, insn); - changed = true; - } - } + private static boolean processStaticFields(ClassNode cls, MethodNode classInitMth) { + List inits = collectFieldsInit(cls, classInitMth, InsnType.SPUT); + if (inits.isEmpty()) { + return false; + } + // ignore field init constant if field initialized in class init method + for (FieldInitInfo fieldInit : inits) { + FieldNode field = fieldInit.fieldNode; + if (field.getAccessFlags().isFinal()) { + field.remove(JadxAttrType.CONSTANT_VALUE); } } - return changed; - } - - private static class InitInfo { - private final MethodNode constrMth; - private final List putInsns = new ArrayList<>(); - - private InitInfo(MethodNode constrMth) { - this.constrMth = constrMth; + filterFieldsInit(inits); + if (inits.isEmpty()) { + return false; } - - public MethodNode getConstrMth() { - return constrMth; - } - - public List getPutInsns() { - return putInsns; + for (FieldInitInfo fieldInit : inits) { + IndexInsnNode insn = fieldInit.putInsn; + InsnArg arg = insn.getArg(0); + if (arg instanceof InsnWrapArg) { + ((InsnWrapArg) arg).getWrapInsn().add(AFlag.DECLARE_VAR); + } + InsnRemover.remove(classInitMth, insn); + addFieldInitAttr(classInitMth, fieldInit.fieldNode, insn); } + fixFieldsOrder(cls, inits); + return true; } private static void moveCommonFieldsInit(ClassNode cls) { - List constrList = getConstructorsList(cls); - if (constrList.isEmpty()) { + List constructors = getConstructorsList(cls); + if (constructors.isEmpty()) { return; } - List infoList = new ArrayList<>(constrList.size()); - for (MethodNode constrMth : constrList) { - if (constrMth.isNoCode()) { + List infoList = new ArrayList<>(constructors.size()); + for (MethodNode constructorMth : constructors) { + if (constructorMth.isNoCode()) { return; } - List enterBlocks = constrMth.getEnterBlock().getCleanSuccessors(); - if (enterBlocks.isEmpty()) { + List inits = collectFieldsInit(cls, constructorMth, InsnType.IPUT); + filterFieldsInit(inits); + if (inits.isEmpty()) { return; } - InitInfo info = new InitInfo(constrMth); - infoList.add(info); - // TODO: check not only first block - BlockNode blockNode = enterBlocks.get(0); - for (InsnNode insn : blockNode.getInstructions()) { - if (insn.getType() == InsnType.IPUT && checkInsn(cls, insn)) { - info.getPutInsns().add(insn); - } else if (!info.getPutInsns().isEmpty()) { - break; - } - } + infoList.add(new ConstructorInitInfo(constructorMth, inits)); } // compare collected instructions - InitInfo common = null; - for (InitInfo info : infoList) { + ConstructorInitInfo common = null; + for (ConstructorInitInfo info : infoList) { if (common == null) { common = info; - } else if (!compareInsns(common.getPutInsns(), info.getPutInsns())) { + continue; + } + if (!compareFieldInits(common.fieldInits, info.fieldInits)) { return; } } if (common == null) { return; } - Set fields = new HashSet<>(); - for (InsnNode insn : common.getPutInsns()) { - FieldInfo fieldInfo = (FieldInfo) ((IndexInsnNode) insn).getIndex(); - FieldNode field = cls.root().resolveField(fieldInfo); - if (field == null) { - return; - } - if (!fields.add(fieldInfo)) { - return; - } - } // all checks passed - for (InitInfo info : infoList) { - for (InsnNode putInsn : info.getPutInsns()) { + for (ConstructorInitInfo info : infoList) { + for (FieldInitInfo fieldInit : info.fieldInits) { + IndexInsnNode putInsn = fieldInit.putInsn; InsnArg arg = putInsn.getArg(0); if (arg instanceof InsnWrapArg) { ((InsnWrapArg) arg).getWrapInsn().add(AFlag.DECLARE_VAR); } - InsnRemover.remove(info.getConstrMth(), putInsn); + InsnRemover.remove(info.constructorMth, putInsn); } } - for (InsnNode insn : common.getPutInsns()) { - FieldInfo fieldInfo = (FieldInfo) ((IndexInsnNode) insn).getIndex(); - FieldNode field = cls.root().resolveField(fieldInfo); - addFieldInitAttr(common.getConstrMth(), field, insn); + for (FieldInitInfo fieldInit : common.fieldInits) { + addFieldInitAttr(common.constructorMth, fieldInit.fieldNode, fieldInit.putInsn); + } + fixFieldsOrder(cls, common.fieldInits); + } + + private static List collectFieldsInit(ClassNode cls, MethodNode mth, InsnType putType) { + List fieldsInit = new ArrayList<>(); + Set singlePathBlocks = new HashSet<>(); + BlockUtils.visitSinglePath(mth.getEnterBlock(), singlePathBlocks::add); + + for (BlockNode block : mth.getBasicBlocks()) { + for (InsnNode insn : block.getInstructions()) { + if (insn.getType() == putType) { + IndexInsnNode putInsn = (IndexInsnNode) insn; + FieldInfo field = (FieldInfo) putInsn.getIndex(); + if (field.getDeclClass().equals(cls.getClassInfo())) { + FieldNode fn = cls.searchField(field); + if (fn != null) { + boolean singlePath = singlePathBlocks.contains(block); + fieldsInit.add(new FieldInitInfo(fn, putInsn, singlePath)); + } + } + } + } + } + return fieldsInit; + } + + private static void filterFieldsInit(List inits) { + // exclude fields initialized several times + Set excludedFields = inits + .stream() + .collect(Collectors.toMap(fi -> fi.fieldNode, fi -> 1, Integer::sum)) + .entrySet() + .stream() + .filter(v -> v.getValue() > 1) + .map(v -> v.getKey().getFieldInfo()) + .collect(Collectors.toSet()); + + for (FieldInitInfo initInfo : inits) { + if (!checkInsn(initInfo)) { + excludedFields.add(initInfo.fieldNode.getFieldInfo()); + } + } + if (!excludedFields.isEmpty()) { + boolean changed; + do { + changed = false; + for (FieldInitInfo initInfo : inits) { + FieldInfo fieldInfo = initInfo.fieldNode.getFieldInfo(); + if (excludedFields.contains(fieldInfo)) { + continue; + } + if (insnUseExcludedField(initInfo, excludedFields)) { + excludedFields.add(fieldInfo); + changed = true; + } + } + } while (changed); + } + + // apply + if (!excludedFields.isEmpty()) { + inits.removeIf(fi -> excludedFields.contains(fi.fieldNode.getFieldInfo())); } } - private static boolean compareInsns(List base, List other) { - if (base.size() != other.size()) { + private static boolean checkInsn(FieldInitInfo initInfo) { + if (!initInfo.singlePath) { return false; } - int count = base.size(); - for (int i = 0; i < count; i++) { - InsnNode baseInsn = base.get(i); - InsnNode otherInsn = other.get(i); - if (!baseInsn.isSame(otherInsn)) { - return false; - } - } - return true; - } - - private static boolean checkInsn(ClassNode cls, InsnNode insn) { - if (insn instanceof IndexInsnNode) { - FieldInfo fieldInfo = (FieldInfo) ((IndexInsnNode) insn).getIndex(); - if (!fieldInfo.getDeclClass().equals(cls.getClassInfo())) { - // exclude fields from super classes - return false; - } - FieldNode fieldNode = cls.root().resolveField(fieldInfo); - if (fieldNode == null) { - // exclude inherited fields (not declared in this class) - return false; - } - } else { - return false; - } - + IndexInsnNode insn = initInfo.putInsn; InsnArg arg = insn.getArg(0); if (arg.isInsnWrap()) { InsnNode wrapInsn = ((InsnWrapArg) arg).getWrapInsn(); @@ -248,6 +249,52 @@ public class ExtractFieldInit extends AbstractVisitor { return true; } + private static boolean insnUseExcludedField(FieldInitInfo initInfo, Set excludedFields) { + if (excludedFields.isEmpty()) { + return false; + } + IndexInsnNode insn = initInfo.putInsn; + boolean staticField = insn.getType() == InsnType.SPUT; + InsnType useType = staticField ? InsnType.SGET : InsnType.IGET; + // exclude if init code use any excluded field + Boolean exclude = insn.visitInsns(innerInsn -> { + if (innerInsn.getType() == useType) { + FieldInfo fieldInfo = (FieldInfo) ((IndexInsnNode) innerInsn).getIndex(); + if (excludedFields.contains(fieldInfo)) { + return true; + } + } + return null; + }); + return Objects.equals(exclude, Boolean.TRUE); + } + + private static void fixFieldsOrder(ClassNode cls, List fieldsInit) { + List clsFields = cls.getFields(); + List orderedFields = Utils.collectionMap(fieldsInit, v -> v.fieldNode); + // check if already ordered + boolean ordered = Collections.indexOfSubList(clsFields, orderedFields) != -1; + if (!ordered) { + clsFields.removeAll(orderedFields); + clsFields.addAll(orderedFields); + } + } + + private static boolean compareFieldInits(List base, List other) { + if (base.size() != other.size()) { + return false; + } + int count = base.size(); + for (int i = 0; i < count; i++) { + InsnNode baseInsn = base.get(i).putInsn; + InsnNode otherInsn = other.get(i).putInsn; + if (!baseInsn.isSame(otherInsn)) { + return false; + } + } + return true; + } + private static List getConstructorsList(ClassNode cls) { List list = new ArrayList<>(); for (MethodNode mth : cls.getMethods()) { @@ -262,26 +309,8 @@ public class ExtractFieldInit extends AbstractVisitor { return list; } - private static List getFieldAssigns(MethodNode mth, FieldNode field, InsnType putInsn) { - if (mth.isNoCode() || mth.getBasicBlocks() == null) { - return Collections.emptyList(); - } - List assignInsns = new ArrayList<>(); - for (BlockNode block : mth.getBasicBlocks()) { - for (InsnNode insn : block.getInstructions()) { - if (insn.getType() == putInsn) { - FieldInfo putNode = (FieldInfo) ((IndexInsnNode) insn).getIndex(); - if (putNode.equals(field.getFieldInfo())) { - assignInsns.add(insn); - } - } - } - } - return assignInsns; - } - - private static void addFieldInitAttr(MethodNode classInitMth, FieldNode field, InsnNode insn) { + private static void addFieldInitAttr(MethodNode mth, FieldNode field, InsnNode insn) { InsnNode assignInsn = InsnNode.wrapArg(insn.getArg(0)); - field.addAttr(new FieldInitInsnAttr(classInitMth, assignInsn)); + field.addAttr(new FieldInitInsnAttr(mth, assignInsn)); } } diff --git a/jadx-core/src/main/java/jadx/core/utils/BlockUtils.java b/jadx-core/src/main/java/jadx/core/utils/BlockUtils.java index 426383af4..0e47ce93a 100644 --- a/jadx-core/src/main/java/jadx/core/utils/BlockUtils.java +++ b/jadx-core/src/main/java/jadx/core/utils/BlockUtils.java @@ -761,6 +761,30 @@ public class BlockUtils { } } + /** + * Visit blocks on path without branching or merging paths. + */ + public static void visitSinglePath(BlockNode startBlock, Consumer visitor) { + if (startBlock == null) { + return; + } + visitor.accept(startBlock); + BlockNode next = getNextSinglePathBlock(startBlock); + while (next != null) { + visitor.accept(next); + next = getNextSinglePathBlock(next); + } + } + + @Nullable + public static BlockNode getNextSinglePathBlock(BlockNode block) { + if (block == null || block.getPredecessors().size() > 1) { + return null; + } + List successors = block.getSuccessors(); + return successors.size() == 1 ? successors.get(0) : null; + } + public static List buildSimplePath(BlockNode block) { if (block == null) { return Collections.emptyList(); diff --git a/jadx-core/src/test/java/jadx/tests/api/utils/TestUtils.java b/jadx-core/src/test/java/jadx/tests/api/utils/TestUtils.java index 21cee99c9..0f5dd297a 100644 --- a/jadx-core/src/test/java/jadx/tests/api/utils/TestUtils.java +++ b/jadx-core/src/test/java/jadx/tests/api/utils/TestUtils.java @@ -24,6 +24,9 @@ public class TestUtils { } public static int count(String string, String substring) { + if (substring == null || substring.isEmpty()) { + throw new IllegalArgumentException("Substring can't be null or empty"); + } int count = 0; int idx = 0; while ((idx = string.indexOf(substring, idx)) != -1) { diff --git a/jadx-core/src/test/java/jadx/tests/integration/others/TestFieldInit.java b/jadx-core/src/test/java/jadx/tests/integration/others/TestFieldInit.java deleted file mode 100644 index 98406e78e..000000000 --- a/jadx-core/src/test/java/jadx/tests/integration/others/TestFieldInit.java +++ /dev/null @@ -1,46 +0,0 @@ -package jadx.tests.integration.others; - -import java.util.ArrayList; -import java.util.List; -import java.util.Random; - -import org.junit.jupiter.api.Test; - -import jadx.tests.api.IntegrationTest; - -import static jadx.tests.api.utils.assertj.JadxAssertions.assertThat; - -public class TestFieldInit extends IntegrationTest { - - public static class TestCls { - - public class A { - } - - public static List s = new ArrayList<>(); - - public A a = new A(); - public int i = 1 + Random.class.getSimpleName().length(); - public int n = 0; - - public TestCls(int z) { - this.n = z; - this.n = 0; - } - } - - @Test - public void test() { - assertThat(getClassNode(TestCls.class)) - .code() - .containsOne("List s = new ArrayList") - .containsOne("A a = new A();") - .containsOneOf( - "int i = (Random.class.getSimpleName().length() + 1);", - "int i = (1 + Random.class.getSimpleName().length());") - .containsOne("int n = 0;") - .doesNotContain("static {") - .containsOne("this.n = z;") - .containsOne("this.n = 0;"); - } -} diff --git a/jadx-core/src/test/java/jadx/tests/integration/others/TestFieldInitInTryCatch.java b/jadx-core/src/test/java/jadx/tests/integration/others/TestFieldInitInTryCatch.java index 241196f5b..6b109064d 100644 --- a/jadx-core/src/test/java/jadx/tests/integration/others/TestFieldInitInTryCatch.java +++ b/jadx-core/src/test/java/jadx/tests/integration/others/TestFieldInitInTryCatch.java @@ -82,6 +82,8 @@ public class TestFieldInitInTryCatch extends IntegrationTest { ClassNode cls = getClassNode(TestCls3.class); String code = cls.getCode().toString(); - assertThat(code, containsOne("public static final String[] A = {\"a\"};")); + // don't move code from try/catch + assertThat(code, containsOne("public static final String[] A;")); + assertThat(code, containsOne("A = new String[]{\"a\"};")); } } diff --git a/jadx-core/src/test/java/jadx/tests/integration/others/TestFieldInitOrder.java b/jadx-core/src/test/java/jadx/tests/integration/others/TestFieldInitOrder.java new file mode 100644 index 000000000..118d60404 --- /dev/null +++ b/jadx-core/src/test/java/jadx/tests/integration/others/TestFieldInitOrder.java @@ -0,0 +1,34 @@ +package jadx.tests.integration.others; + +import org.junit.jupiter.api.Test; + +import jadx.tests.api.IntegrationTest; + +import static jadx.tests.api.utils.assertj.JadxAssertions.assertThat; + +public class TestFieldInitOrder extends IntegrationTest { + + public static class TestCls { + private final StringBuilder sb = new StringBuilder(); + private final String a = sb.append("a").toString(); + private final String b = sb.append("b").toString(); + private final String c = sb.append("c").toString(); + private final String result = sb.toString(); + + public void check() { + assertThat(result).isEqualTo("abc"); + assertThat(a).isEqualTo("a"); + assertThat(b).isEqualTo("ab"); + assertThat(c).isEqualTo("abc"); + } + } + + @Test + public void test() { + assertThat(getClassNode(TestCls.class)) + .code() + .doesNotContain("TestCls() {") // constructor removed + .doesNotContain("String result;") + .containsOne("String result = this.sb.toString();"); + } +} diff --git a/jadx-core/src/test/java/jadx/tests/integration/others/TestFieldInitOrderStatic.java b/jadx-core/src/test/java/jadx/tests/integration/others/TestFieldInitOrderStatic.java new file mode 100644 index 000000000..2007f35f6 --- /dev/null +++ b/jadx-core/src/test/java/jadx/tests/integration/others/TestFieldInitOrderStatic.java @@ -0,0 +1,35 @@ +package jadx.tests.integration.others; + +import org.junit.jupiter.api.Test; + +import jadx.tests.api.IntegrationTest; + +import static jadx.tests.api.utils.assertj.JadxAssertions.assertThat; + +public class TestFieldInitOrderStatic extends IntegrationTest { + + @SuppressWarnings("ConstantName") + public static class TestCls { + private static final StringBuilder sb = new StringBuilder(); + private static final String a = sb.append("a").toString(); + private static final String b = sb.append("b").toString(); + private static final String c = sb.append("c").toString(); + private static final String result = sb.toString(); + + public void check() { + assertThat(result).isEqualTo("abc"); + assertThat(a).isEqualTo("a"); + assertThat(b).isEqualTo("ab"); + assertThat(c).isEqualTo("abc"); + } + } + + @Test + public void test() { + assertThat(getClassNode(TestCls.class)) + .code() + .doesNotContain("static {") + .doesNotContain("String result;") + .containsOne("String result = sb.toString();"); + } +}