diff --git a/jadx-core/src/main/java/jadx/core/Jadx.java b/jadx-core/src/main/java/jadx/core/Jadx.java index 91ea09058..506c8a4ab 100644 --- a/jadx-core/src/main/java/jadx/core/Jadx.java +++ b/jadx-core/src/main/java/jadx/core/Jadx.java @@ -14,6 +14,7 @@ import jadx.api.JadxArgs; import jadx.core.dex.visitors.ClassModifier; import jadx.core.dex.visitors.ConstInlineVisitor; import jadx.core.dex.visitors.ConstructorVisitor; +import jadx.core.dex.visitors.DeboxingVisitor; import jadx.core.dex.visitors.DependencyCollector; import jadx.core.dex.visitors.DotGraphVisitor; import jadx.core.dex.visitors.EnumVisitor; @@ -87,6 +88,7 @@ public class Jadx { passes.add(new DebugInfoApplyVisitor()); } + passes.add(new DeboxingVisitor()); passes.add(new ModVisitor()); passes.add(new CodeShrinkVisitor()); passes.add(new ReSugarCode()); diff --git a/jadx-core/src/main/java/jadx/core/codegen/AnnotationGen.java b/jadx-core/src/main/java/jadx/core/codegen/AnnotationGen.java index 9a582f6eb..7bcf11ab5 100644 --- a/jadx-core/src/main/java/jadx/core/codegen/AnnotationGen.java +++ b/jadx-core/src/main/java/jadx/core/codegen/AnnotationGen.java @@ -147,7 +147,7 @@ public class AnnotationGen { if (val instanceof String) { code.add(getStringUtils().unescapeString((String) val)); } else if (val instanceof Integer) { - code.add(TypeGen.formatInteger((Integer) val)); + code.add(TypeGen.formatInteger((Integer) val, false)); } else if (val instanceof Character) { code.add(getStringUtils().unescapeChar((Character) val)); } else if (val instanceof Boolean) { @@ -157,11 +157,11 @@ public class AnnotationGen { } else if (val instanceof Double) { code.add(TypeGen.formatDouble((Double) val)); } else if (val instanceof Long) { - code.add(TypeGen.formatLong((Long) val)); + code.add(TypeGen.formatLong((Long) val, false)); } else if (val instanceof Short) { - code.add(TypeGen.formatShort((Short) val)); + code.add(TypeGen.formatShort((Short) val, false)); } else if (val instanceof Byte) { - code.add(TypeGen.formatByte((Byte) val)); + code.add(TypeGen.formatByte((Byte) val, false)); } else if (val instanceof ArgType) { classGen.useType(code, (ArgType) val); code.add(".class"); diff --git a/jadx-core/src/main/java/jadx/core/codegen/InsnGen.java b/jadx-core/src/main/java/jadx/core/codegen/InsnGen.java index 4d8031949..f5687fb5a 100644 --- a/jadx-core/src/main/java/jadx/core/codegen/InsnGen.java +++ b/jadx-core/src/main/java/jadx/core/codegen/InsnGen.java @@ -132,7 +132,7 @@ public class InsnGen { } private String lit(LiteralArg arg) { - return TypeGen.literalToString(arg.getLiteral(), arg.getType(), mth, fallback); + return TypeGen.literalToString(arg, mth, fallback); } private void instanceField(CodeWriter code, FieldInfo field, InsnArg arg) throws CodegenException { diff --git a/jadx-core/src/main/java/jadx/core/codegen/TypeGen.java b/jadx-core/src/main/java/jadx/core/codegen/TypeGen.java index 1370ecb10..40a67d62c 100644 --- a/jadx-core/src/main/java/jadx/core/codegen/TypeGen.java +++ b/jadx-core/src/main/java/jadx/core/codegen/TypeGen.java @@ -4,7 +4,9 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import jadx.core.deobf.NameMapper; +import jadx.core.dex.attributes.AFlag; import jadx.core.dex.instructions.args.ArgType; +import jadx.core.dex.instructions.args.LiteralArg; import jadx.core.dex.instructions.args.PrimitiveType; import jadx.core.dex.nodes.IDexNode; import jadx.core.utils.StringUtils; @@ -28,16 +30,26 @@ public class TypeGen { return stype.getShortName(); } + /** + * Convert literal arg to string (preferred method) + */ + public static String literalToString(LiteralArg arg, IDexNode dexNode, boolean fallback) { + return literalToString(arg.getLiteral(), arg.getType(), + dexNode.root().getStringUtils(), + fallback, + arg.contains(AFlag.EXPLICIT_PRIMITIVE_TYPE)); + } + /** * Convert literal value to string according to value type * * @throws JadxRuntimeException for incorrect type or literal value */ public static String literalToString(long lit, ArgType type, IDexNode dexNode, boolean fallback) { - return literalToString(lit, type, dexNode.root().getStringUtils(), fallback); + return literalToString(lit, type, dexNode.root().getStringUtils(), fallback, false); } - public static String literalToString(long lit, ArgType type, StringUtils stringUtils, boolean fallback) { + public static String literalToString(long lit, ArgType type, StringUtils stringUtils, boolean fallback, boolean cast) { if (type == null || !type.isTypeKnown()) { String n = Long.toString(lit); if (fallback && Math.abs(lit) > 100) { @@ -65,13 +77,13 @@ public class TypeGen { } return stringUtils.unescapeChar(ch); case BYTE: - return formatByte(lit); + return formatByte(lit, cast); case SHORT: - return formatShort(lit); + return formatShort(lit, cast); case INT: - return formatInteger(lit); + return formatInteger(lit, cast); case LONG: - return formatLong(lit); + return formatLong(lit, cast); case FLOAT: return formatFloat(Float.intBitsToFloat((int) lit)); case DOUBLE: @@ -90,37 +102,40 @@ public class TypeGen { } } - public static String formatShort(long l) { + public static String formatShort(long l, boolean cast) { if (l == Short.MAX_VALUE) { return "Short.MAX_VALUE"; } if (l == Short.MIN_VALUE) { return "Short.MIN_VALUE"; } - return Long.toString(l); + String str = Long.toString(l); + return cast ? "(short) " + str : str; } - public static String formatByte(long l) { + public static String formatByte(long l, boolean cast) { if (l == Byte.MAX_VALUE) { return "Byte.MAX_VALUE"; } if (l == Byte.MIN_VALUE) { return "Byte.MIN_VALUE"; } - return Long.toString(l); + String str = Long.toString(l); + return cast ? "(byte) " + str : str; } - public static String formatInteger(long l) { + public static String formatInteger(long l, boolean cast) { if (l == Integer.MAX_VALUE) { return "Integer.MAX_VALUE"; } if (l == Integer.MIN_VALUE) { return "Integer.MIN_VALUE"; } - return Long.toString(l); + String str = Long.toString(l); + return cast ? "(int) " + str : str; } - public static String formatLong(long l) { + public static String formatLong(long l, boolean cast) { if (l == Long.MAX_VALUE) { return "Long.MAX_VALUE"; } @@ -128,8 +143,8 @@ public class TypeGen { return "Long.MIN_VALUE"; } String str = Long.toString(l); - if (Math.abs(l) >= Integer.MAX_VALUE) { - str += 'L'; + if (cast || Math.abs(l) >= Integer.MAX_VALUE) { + return str + 'L'; } return str; } diff --git a/jadx-core/src/main/java/jadx/core/dex/attributes/AFlag.java b/jadx-core/src/main/java/jadx/core/dex/attributes/AFlag.java index 2b2a549f7..27c2e3895 100644 --- a/jadx-core/src/main/java/jadx/core/dex/attributes/AFlag.java +++ b/jadx-core/src/main/java/jadx/core/dex/attributes/AFlag.java @@ -53,5 +53,10 @@ public enum AFlag { EXPLICIT_GENERICS, + /** + * Use constants with explicit type: cast '(byte) 1' or type letter '7L' + */ + EXPLICIT_PRIMITIVE_TYPE, + INCONSISTENT_CODE, // warning about incorrect decompilation } diff --git a/jadx-core/src/main/java/jadx/core/dex/info/MethodInfo.java b/jadx-core/src/main/java/jadx/core/dex/info/MethodInfo.java index f46f62c17..a37c45dbc 100644 --- a/jadx-core/src/main/java/jadx/core/dex/info/MethodInfo.java +++ b/jadx-core/src/main/java/jadx/core/dex/info/MethodInfo.java @@ -35,13 +35,12 @@ public final class MethodInfo { private MethodInfo(ClassInfo declClass, String name, List args, ArgType retType) { this.name = name; - alias = name; - aliasFromPreset = false; + this.alias = name; + this.aliasFromPreset = false; this.declClass = declClass; - this.args = args; this.retType = retType; - shortId = makeSignature(true); + this.shortId = makeSignature(true); } public static MethodInfo externalMth(ClassInfo declClass, String name, List args, ArgType retType) { diff --git a/jadx-core/src/main/java/jadx/core/dex/instructions/CallMthInterface.java b/jadx-core/src/main/java/jadx/core/dex/instructions/CallMthInterface.java index 2fc72a3e2..f5649c305 100644 --- a/jadx-core/src/main/java/jadx/core/dex/instructions/CallMthInterface.java +++ b/jadx-core/src/main/java/jadx/core/dex/instructions/CallMthInterface.java @@ -1,8 +1,11 @@ package jadx.core.dex.instructions; import jadx.core.dex.info.MethodInfo; +import jadx.core.dex.instructions.args.RegisterArg; public interface CallMthInterface { MethodInfo getCallMth(); + + RegisterArg getInstanceArg(); } diff --git a/jadx-core/src/main/java/jadx/core/dex/instructions/InvokeNode.java b/jadx-core/src/main/java/jadx/core/dex/instructions/InvokeNode.java index 601ed9c40..ad3995037 100644 --- a/jadx-core/src/main/java/jadx/core/dex/instructions/InvokeNode.java +++ b/jadx-core/src/main/java/jadx/core/dex/instructions/InvokeNode.java @@ -1,10 +1,13 @@ package jadx.core.dex.instructions; +import org.jetbrains.annotations.Nullable; + import com.android.dx.io.instructions.DecodedInstruction; import jadx.core.dex.info.MethodInfo; import jadx.core.dex.instructions.args.ArgType; import jadx.core.dex.instructions.args.InsnArg; +import jadx.core.dex.instructions.args.RegisterArg; import jadx.core.dex.nodes.InsnNode; import jadx.core.utils.InsnUtils; import jadx.core.utils.Utils; @@ -51,6 +54,18 @@ public class InvokeNode extends InsnNode implements CallMthInterface { return mth; } + @Override + @Nullable + public RegisterArg getInstanceArg() { + if (type != InvokeType.STATIC && getArgsCount() > 0) { + InsnArg firstArg = getArg(0); + if (firstArg.isRegister()) { + return ((RegisterArg) firstArg); + } + } + return null; + } + @Override public InsnNode copy() { return copyCommonParams(new InvokeNode(mth, type, getArgsCount())); diff --git a/jadx-core/src/main/java/jadx/core/dex/instructions/args/LiteralArg.java b/jadx-core/src/main/java/jadx/core/dex/instructions/args/LiteralArg.java index 40a901215..b0a6c5baa 100644 --- a/jadx-core/src/main/java/jadx/core/dex/instructions/args/LiteralArg.java +++ b/jadx-core/src/main/java/jadx/core/dex/instructions/args/LiteralArg.java @@ -77,7 +77,7 @@ public final class LiteralArg extends InsnArg { @Override public String toString() { try { - String value = TypeGen.literalToString(literal, getType(), DEF_STRING_UTILS, true); + String value = TypeGen.literalToString(literal, getType(), DEF_STRING_UTILS, true, false); if (getType().equals(ArgType.BOOLEAN) && (value.equals("true") || value.equals("false"))) { return value; } diff --git a/jadx-core/src/main/java/jadx/core/dex/instructions/args/RegisterArg.java b/jadx-core/src/main/java/jadx/core/dex/instructions/args/RegisterArg.java index ed03eb8d9..c536e0b32 100644 --- a/jadx-core/src/main/java/jadx/core/dex/instructions/args/RegisterArg.java +++ b/jadx-core/src/main/java/jadx/core/dex/instructions/args/RegisterArg.java @@ -55,8 +55,7 @@ public class RegisterArg extends InsnArg implements Named { if (sVar != null) { return sVar.getTypeInfo().getType(); } - LOG.warn("Register type unknown, SSA variable not initialized: r{}", regNum); - return type; + return ArgType.UNKNOWN; } public ArgType getInitType() { diff --git a/jadx-core/src/main/java/jadx/core/dex/instructions/mods/ConstructorInsn.java b/jadx-core/src/main/java/jadx/core/dex/instructions/mods/ConstructorInsn.java index d4dc6b6bb..ba3c0f205 100644 --- a/jadx-core/src/main/java/jadx/core/dex/instructions/mods/ConstructorInsn.java +++ b/jadx-core/src/main/java/jadx/core/dex/instructions/mods/ConstructorInsn.java @@ -63,6 +63,7 @@ public class ConstructorInsn extends InsnNode implements CallMthInterface { return callMth; } + @Override public RegisterArg getInstanceArg() { return instanceArg; } diff --git a/jadx-core/src/main/java/jadx/core/dex/visitors/ConstInlineVisitor.java b/jadx-core/src/main/java/jadx/core/dex/visitors/ConstInlineVisitor.java index 98328da5f..02df79678 100644 --- a/jadx-core/src/main/java/jadx/core/dex/visitors/ConstInlineVisitor.java +++ b/jadx-core/src/main/java/jadx/core/dex/visitors/ConstInlineVisitor.java @@ -40,6 +40,10 @@ public class ConstInlineVisitor extends AbstractVisitor { if (mth.isNoCode()) { return; } + process(mth); + } + + public static void process(MethodNode mth) { List toRemove = new ArrayList<>(); for (BlockNode block : mth.getBasicBlocks()) { toRemove.clear(); @@ -175,17 +179,19 @@ public class ConstInlineVisitor extends AbstractVisitor { if (constArg.isLiteral()) { long literal = ((LiteralArg) constArg).getLiteral(); - ArgType argType = arg.getInitType(); + ArgType argType = arg.getType(); + if (argType == ArgType.UNKNOWN) { + argType = arg.getInitType(); + } if (argType.isObject() && literal != 0) { argType = ArgType.NARROW_NUMBERS; } LiteralArg litArg = InsnArg.lit(literal, argType); + litArg.copyAttributesFrom(constArg); if (!useInsn.replaceArg(arg, litArg)) { return false; } // arg replaced, made some optimizations - litArg.setType(arg.getInitType()); - FieldNode fieldNode = null; ArgType litArgType = litArg.getType(); if (litArgType.isTypeKnown()) { diff --git a/jadx-core/src/main/java/jadx/core/dex/visitors/DeboxingVisitor.java b/jadx-core/src/main/java/jadx/core/dex/visitors/DeboxingVisitor.java new file mode 100644 index 000000000..5100082f0 --- /dev/null +++ b/jadx-core/src/main/java/jadx/core/dex/visitors/DeboxingVisitor.java @@ -0,0 +1,149 @@ +package jadx.core.dex.visitors; + +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +import jadx.core.dex.attributes.AFlag; +import jadx.core.dex.info.ClassInfo; +import jadx.core.dex.info.MethodInfo; +import jadx.core.dex.instructions.InsnType; +import jadx.core.dex.instructions.InvokeNode; +import jadx.core.dex.instructions.InvokeType; +import jadx.core.dex.instructions.args.ArgType; +import jadx.core.dex.instructions.args.InsnArg; +import jadx.core.dex.instructions.args.RegisterArg; +import jadx.core.dex.nodes.BlockNode; +import jadx.core.dex.nodes.InsnNode; +import jadx.core.dex.nodes.MethodNode; +import jadx.core.dex.nodes.RootNode; +import jadx.core.dex.visitors.regions.variables.ProcessVariables; +import jadx.core.dex.visitors.shrink.CodeShrinkVisitor; +import jadx.core.utils.BlockUtils; +import jadx.core.utils.exceptions.JadxException; + +/** + * Remove primitives boxing + * i.e convert 'Integer.valueOf(1)' to '1' + */ +@JadxVisitor( + name = "DeboxingVisitor", + desc = "Remove primitives boxing", + runBefore = { + CodeShrinkVisitor.class, + ProcessVariables.class + } +) +public class DeboxingVisitor extends AbstractVisitor { + + private Set valueOfMths; + + @Override + public void init(RootNode root) { + valueOfMths = new HashSet<>(); + valueOfMths.add(valueOfMth(root, ArgType.INT, "java.lang.Integer")); + valueOfMths.add(valueOfMth(root, ArgType.BOOLEAN, "java.lang.Boolean")); + valueOfMths.add(valueOfMth(root, ArgType.BYTE, "java.lang.Byte")); + valueOfMths.add(valueOfMth(root, ArgType.SHORT, "java.lang.Short")); + valueOfMths.add(valueOfMth(root, ArgType.CHAR, "java.lang.Character")); + valueOfMths.add(valueOfMth(root, ArgType.LONG, "java.lang.Long")); + } + + private static MethodInfo valueOfMth(RootNode root, ArgType argType, String clsName) { + ArgType boxType = ArgType.object(clsName); + ClassInfo boxCls = ClassInfo.fromType(root, boxType); + return MethodInfo.externalMth(boxCls, "valueOf", Collections.singletonList(argType), boxType); + } + + @Override + public void visit(MethodNode mth) throws JadxException { + if (mth.isNoCode()) { + return; + } + boolean replaced = false; + for (BlockNode blockNode : mth.getBasicBlocks()) { + List insnList = blockNode.getInstructions(); + int count = insnList.size(); + for (int i = 0; i < count; i++) { + InsnNode insnNode = insnList.get(i); + if (insnNode.getType() == InsnType.INVOKE) { + InsnNode replaceInsn = checkForReplace(((InvokeNode) insnNode)); + if (replaceInsn != null) { + BlockUtils.replaceInsn(blockNode, i, replaceInsn); + replaced = true; + } + } + } + } + if (replaced) { + ConstInlineVisitor.process(mth); + } + } + + private InsnNode checkForReplace(InvokeNode insnNode) { + if (insnNode.getInvokeType() != InvokeType.STATIC + || insnNode.getResult() == null) { + return null; + } + MethodInfo callMth = insnNode.getCallMth(); + if (valueOfMths.contains(callMth)) { + RegisterArg resArg = insnNode.getResult(); + InsnArg arg = insnNode.getArg(0); + if (arg.isLiteral() && checkArgUsage(resArg)) { + ArgType primitiveType = callMth.getArgumentsTypes().get(0); + ArgType boxType = callMth.getReturnType(); + if (isNeedExplicitCast(resArg, primitiveType, boxType)) { + arg.add(AFlag.EXPLICIT_PRIMITIVE_TYPE); + } + resArg.setType(primitiveType); + arg.setType(primitiveType); + + InsnNode constInsn = new InsnNode(InsnType.CONST, 1); + constInsn.addArg(arg); + constInsn.setResult(resArg); + return constInsn; + } + } + return null; + } + + private boolean isNeedExplicitCast(RegisterArg resArg, ArgType primitiveType, ArgType boxType) { + if (primitiveType == ArgType.LONG) { + return true; + } + if (primitiveType != ArgType.INT) { + Set useTypes = collectUseTypes(resArg); + useTypes.add(resArg.getType()); + useTypes.remove(boxType); + useTypes.remove(primitiveType); + return !useTypes.isEmpty(); + } + return false; + } + + private boolean checkArgUsage(RegisterArg arg) { + for (RegisterArg useArg : arg.getSVar().getUseList()) { + InsnNode parentInsn = useArg.getParentInsn(); + if (parentInsn == null) { + return false; + } + if (parentInsn.getType() == InsnType.INVOKE) { + InvokeNode invokeNode = (InvokeNode) parentInsn; + if (useArg.equals(invokeNode.getInstanceArg())) { + return false; + } + } + } + return true; + } + + private Set collectUseTypes(RegisterArg arg) { + Set types = new HashSet<>(); + for (RegisterArg useArg : arg.getSVar().getUseList()) { + types.add(useArg.getType()); + types.add(useArg.getInitType()); + } + return types; + } +} diff --git a/jadx-core/src/test/java/jadx/tests/integration/invoke/TestVarArg.java b/jadx-core/src/test/java/jadx/tests/integration/invoke/TestVarArg.java index 3ade2867e..5ca48e23f 100644 --- a/jadx-core/src/test/java/jadx/tests/integration/invoke/TestVarArg.java +++ b/jadx-core/src/test/java/jadx/tests/integration/invoke/TestVarArg.java @@ -37,7 +37,7 @@ public class TestVarArg extends IntegrationTest { assertThat(code, containsString("void test2(int i, Object... a) {")); assertThat(code, containsString("test1(1, 2);")); - assertThat(code, containsString("test2(3, \"1\", Integer.valueOf(7));")); + assertThat(code, containsString("test2(3, \"1\", 7);")); // negative case assertThat(code, containsString("void test3(int[] a) {")); diff --git a/jadx-core/src/test/java/jadx/tests/integration/others/TestDeboxing.java b/jadx-core/src/test/java/jadx/tests/integration/others/TestDeboxing.java new file mode 100644 index 000000000..efac96999 --- /dev/null +++ b/jadx-core/src/test/java/jadx/tests/integration/others/TestDeboxing.java @@ -0,0 +1,77 @@ +package jadx.tests.integration.others; + +import org.junit.jupiter.api.Test; + +import jadx.core.dex.nodes.ClassNode; +import jadx.tests.api.IntegrationTest; + +import static jadx.tests.api.utils.JadxMatchers.containsOne; +import static jadx.tests.api.utils.JadxMatchers.countString; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.is; + +public class TestDeboxing extends IntegrationTest { + + public static class TestCls { + public Object testInt() { + return 1; + } + + public Object testBoolean() { + return true; + } + + public Object testByte() { + return (byte) 2; + } + + public Short testShort() { + return 3; + } + + public Character testChar() { + return 'c'; + } + + public Long testLong() { + return 4L; + } + + public void testConstInline() { + Boolean v = true; + use(v); + use(v); + } + + private void use(Boolean v) { + } + + public void check() { + // don't mind weird comparisons + // need to get primitive without using boxing or literal + // otherwise will get same result after decompilation + assertThat(testInt(), is(Integer.sum(0, 1))); + assertThat(testBoolean(), is(Boolean.TRUE)); + assertThat(testByte(), is(Byte.parseByte("2"))); + assertThat(testShort(), is(Short.parseShort("3"))); + assertThat(testChar(), is("c".charAt(0))); + assertThat(testLong(), is(Long.valueOf("4"))); + testConstInline(); + } + } + + @Test + public void test() { + noDebugInfo(); + ClassNode cls = getClassNode(TestCls.class); + String code = cls.getCode().toString(); + + assertThat(code, containsOne("return 1;")); + assertThat(code, containsOne("return true;")); + assertThat(code, containsOne("return (byte) 2;")); + assertThat(code, containsOne("return 3;")); + assertThat(code, containsOne("return 'c';")); + assertThat(code, containsOne("return 4L;")); + assertThat(code, countString(2, "use(true);")); + } +}