fix: check enum constructor content before removing (#922)

This commit is contained in:
Skylot
2020-05-03 18:16:59 +01:00
parent 2dce1c0ad9
commit f3cd4e38d7
7 changed files with 198 additions and 19 deletions
@@ -9,6 +9,8 @@ import java.util.List;
import java.util.Objects;
import java.util.Set;
import org.jetbrains.annotations.Nullable;
import com.android.dx.rop.code.AccessFlags;
import com.google.common.collect.Streams;
@@ -21,6 +23,7 @@ import jadx.core.dex.attributes.nodes.EnumClassAttr;
import jadx.core.dex.attributes.nodes.EnumClassAttr.EnumField;
import jadx.core.dex.attributes.nodes.JadxError;
import jadx.core.dex.attributes.nodes.LineAttrNode;
import jadx.core.dex.attributes.nodes.SkipMethodArgsAttr;
import jadx.core.dex.info.AccessInfo;
import jadx.core.dex.info.ClassInfo;
import jadx.core.dex.instructions.args.ArgType;
@@ -420,12 +423,13 @@ public class ClassGen {
EnumField f = it.next();
code.startLine(f.getField().getAlias());
ConstructorInsn constrInsn = f.getConstrInsn();
if (constrInsn.getArgsCount() > f.getStartArg()) {
MethodNode callMth = cls.dex().resolveMethod(constrInsn.getCallMth());
int skipCount = getEnumCtrSkipArgsCount(callMth);
if (constrInsn.getArgsCount() > skipCount) {
if (igen == null) {
igen = makeInsnGen(enumFields.getStaticMethod());
}
MethodNode callMth = cls.dex().resolveMethod(constrInsn.getCallMth());
igen.generateMethodArguments(code, constrInsn, f.getStartArg(), callMth);
igen.generateMethodArguments(code, constrInsn, 0, callMth);
}
if (f.getCls() != null) {
code.add(' ');
@@ -446,6 +450,16 @@ public class ClassGen {
}
}
private int getEnumCtrSkipArgsCount(@Nullable MethodNode callMth) {
if (callMth != null) {
SkipMethodArgsAttr skipArgsAttr = callMth.get(AType.SKIP_MTH_ARGS);
if (skipArgsAttr != null) {
return skipArgsAttr.getSkipCount();
}
}
return 0;
}
private InsnGen makeInsnGen(MethodNode mth) {
MethodGen mthGen = new MethodGen(this, mth);
return new InsnGen(mthGen, false);
@@ -14,13 +14,11 @@ public class EnumClassAttr implements IAttribute {
public static class EnumField {
private final FieldNode field;
private final ConstructorInsn constrInsn;
private final int startArg;
private ClassNode cls;
public EnumField(FieldNode field, ConstructorInsn co, int startArg) {
public EnumField(FieldNode field, ConstructorInsn co) {
this.field = field;
this.constrInsn = co;
this.startArg = startArg;
}
public FieldNode getField() {
@@ -31,10 +29,6 @@ public class EnumClassAttr implements IAttribute {
return constrInsn;
}
public int getStartArg() {
return startArg;
}
public ClassNode getCls() {
return cls;
}
@@ -55,6 +55,10 @@ public class SkipMethodArgsAttr implements IAttribute {
return skipArgs.get(argNum);
}
public int getSkipCount() {
return skipArgs.cardinality();
}
@Override
public AType<SkipMethodArgsAttr> getType() {
return AType.SKIP_MTH_ARGS;
@@ -19,6 +19,7 @@ import jadx.core.dex.attributes.AFlag;
import jadx.core.dex.attributes.AType;
import jadx.core.dex.attributes.nodes.EnumClassAttr;
import jadx.core.dex.attributes.nodes.EnumClassAttr.EnumField;
import jadx.core.dex.attributes.nodes.SkipMethodArgsAttr;
import jadx.core.dex.info.AccessInfo;
import jadx.core.dex.info.ClassInfo;
import jadx.core.dex.info.FieldInfo;
@@ -287,8 +288,13 @@ public class EnumVisitor extends AbstractVisitor {
if (!clsInfo.equals(cls.getClassInfo()) && !constrCls.getAccessFlags().isEnum()) {
return null;
}
int startArg = co.getArgsCount() == 1 ? 1 : 2;
return new EnumField(enumFieldNode, co, startArg);
MethodInfo callMth = co.getCallMth();
MethodNode mth = cls.dex().resolveMethod(callMth);
if (mth == null) {
return null;
}
markArgsForSkip(mth);
return new EnumField(enumFieldNode, co);
}
@Nullable
@@ -306,10 +312,7 @@ public class EnumVisitor extends AbstractVisitor {
}
private void removeEnumMethods(ClassNode cls, ArgType clsType, FieldNode valuesField) {
String enumConstructor = "<init>(Ljava/lang/String;I)V";
String enumConstructorAlt = "<init>(Ljava/lang/String;)V";
String valuesMethod = "values()" + TypeGen.signature(ArgType.array(clsType));
FieldInfo valuesFieldInfo = valuesField.getFieldInfo();
// remove compiler generated methods
@@ -319,12 +322,11 @@ public class EnumVisitor extends AbstractVisitor {
continue;
}
String shortId = mi.getShortId();
boolean isSynthetic = mth.getAccessFlags().isSynthetic();
if (mi.isConstructor() && !isSynthetic) {
if (shortId.equals(enumConstructor)
|| shortId.equals(enumConstructorAlt)) {
if (mi.isConstructor()) {
if (isDefaultConstructor(mth, shortId)) {
mth.add(AFlag.DONT_GENERATE);
}
markArgsForSkip(mth);
} else if (shortId.equals(valuesMethod)
|| usesValuesField(mth, valuesFieldInfo)
|| simpleValueOfMth(mth, clsType)) {
@@ -333,6 +335,24 @@ public class EnumVisitor extends AbstractVisitor {
}
}
private void markArgsForSkip(MethodNode mth) {
// skip first and second args
SkipMethodArgsAttr.skipArg(mth, 0);
if (mth.getMethodInfo().getArgsCount() > 1) {
SkipMethodArgsAttr.skipArg(mth, 1);
}
}
private boolean isDefaultConstructor(MethodNode mth, String shortId) {
boolean defaultId = shortId.equals("<init>(Ljava/lang/String;I)V")
|| shortId.equals("<init>(Ljava/lang/String;)V");
if (defaultId) {
// check content
return mth.countInsns() == 0;
}
return false;
}
private boolean simpleValueOfMth(MethodNode mth, ArgType clsType) {
InsnNode returnInsn = InsnUtils.searchSingleReturnInsn(mth, insn -> insn.getArgsCount() == 1);
if (returnInsn == null) {
@@ -0,0 +1,60 @@
package jadx.tests.integration.enums;
import org.junit.jupiter.api.Test;
import jadx.tests.api.IntegrationTest;
import static jadx.tests.api.utils.assertj.JadxAssertions.assertThat;
import static jadx.tests.integration.enums.TestEnums2a.TestCls.DoubleOperations.DIVIDE;
import static jadx.tests.integration.enums.TestEnums2a.TestCls.DoubleOperations.TIMES;
public class TestEnums2a extends IntegrationTest {
public static class TestCls {
public interface IOps {
double apply(double x, double y);
}
public enum DoubleOperations implements IOps {
TIMES("*") {
@Override
public double apply(double x, double y) {
return x * y;
}
},
DIVIDE("/") {
@Override
public double apply(double x, double y) {
return x / y;
}
};
private final String op;
DoubleOperations(String op) {
this.op = op;
}
public String getOp() {
return op;
}
}
public void check() {
assertThat(TIMES.getOp()).isEqualTo("*");
assertThat(DIVIDE.getOp()).isEqualTo("/");
assertThat(TIMES.apply(2, 3)).isEqualTo(6);
assertThat(DIVIDE.apply(10, 5)).isEqualTo(2);
}
}
@Test
public void test() {
assertThat(getClassNode(TestCls.class))
.code()
.containsOne("TIMES(\"*\") {")
.containsOne("DIVIDE(\"/\")");
}
}
@@ -0,0 +1,46 @@
package jadx.tests.integration.enums;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;
import jadx.tests.api.IntegrationTest;
import static jadx.tests.api.utils.assertj.JadxAssertions.assertThat;
public class TestEnums6 extends IntegrationTest {
public static class TestCls {
public enum Numbers {
ZERO,
ONE(1);
private final int n;
Numbers() {
this(0);
}
Numbers(int n) {
this.n = n;
}
public int getN() {
return n;
}
}
public void check() {
Assertions.assertThat(TestCls.Numbers.ZERO.getN()).isEqualTo(0);
Assertions.assertThat(TestCls.Numbers.ONE.getN()).isEqualTo(1);
}
}
@Test
public void test() {
assertThat(getClassNode(TestCls.class))
.code()
.containsOne("ZERO,")
.containsOne("Numbers() {")
.containsOne("ONE(1);");
}
}
@@ -0,0 +1,41 @@
package jadx.tests.integration.enums;
import org.junit.jupiter.api.Test;
import jadx.tests.api.IntegrationTest;
import static jadx.tests.api.utils.assertj.JadxAssertions.assertThat;
public class TestEnums7 extends IntegrationTest {
public static class TestCls {
public enum Numbers {
ZERO,
ONE;
private final int n;
Numbers() {
this.n = this.name().equals("ZERO") ? 0 : 1;
}
public int getN() {
return n;
}
}
public void check() {
assertThat(Numbers.ZERO.getN()).isEqualTo(0);
assertThat(Numbers.ONE.getN()).isEqualTo(1);
}
}
@Test
public void test() {
assertThat(getClassNode(TestCls.class))
.code()
.containsOne("ZERO,")
.containsOne("ONE;")
.containsOne("Numbers() {");
}
}