diff --git a/jadx-core/src/main/java/jadx/core/dex/nodes/ClassNode.java b/jadx-core/src/main/java/jadx/core/dex/nodes/ClassNode.java index ac68d653a..9ee9a729b 100644 --- a/jadx-core/src/main/java/jadx/core/dex/nodes/ClassNode.java +++ b/jadx-core/src/main/java/jadx/core/dex/nodes/ClassNode.java @@ -12,6 +12,7 @@ import jadx.core.dex.info.ClassInfo; import jadx.core.dex.info.FieldInfo; import jadx.core.dex.info.MethodInfo; import jadx.core.dex.instructions.args.ArgType; +import jadx.core.dex.instructions.args.LiteralArg; import jadx.core.dex.nodes.parser.AnnotationsParser; import jadx.core.dex.nodes.parser.FieldValueAttr; import jadx.core.dex.nodes.parser.StaticValuesParser; @@ -251,6 +252,23 @@ public class ClassNode extends LineAttrNode implements ILoadable { return field; } + public FieldNode getConstFieldByLiteralArg(LiteralArg arg) { + ArgType type = arg.getType(); + long literal = arg.getLiteral(); + + if (type.equals(ArgType.DOUBLE)) + return getConstField(Double.longBitsToDouble(literal)); + else if (type.equals(ArgType.FLOAT)) + return getConstField(Float.intBitsToFloat((int) literal)); + else if (Math.abs(literal) > 0x1) { + if (type.equals(ArgType.INT)) + return getConstField((int) literal); + else if (type.equals(ArgType.LONG)) + return getConstField(literal); + } + return null; + } + public FieldNode searchFieldById(int id) { String name = FieldInfo.getNameById(dex, id); for (FieldNode f : fields) { diff --git a/jadx-core/src/main/java/jadx/core/dex/visitors/BlockMakerVisitor.java b/jadx-core/src/main/java/jadx/core/dex/visitors/BlockMakerVisitor.java index c79ecb782..220bdbb1d 100644 --- a/jadx-core/src/main/java/jadx/core/dex/visitors/BlockMakerVisitor.java +++ b/jadx-core/src/main/java/jadx/core/dex/visitors/BlockMakerVisitor.java @@ -349,44 +349,70 @@ public class BlockMakerVisitor extends AbstractVisitor { return true; } } - - // splice return block if several precessors presents - if (false && block.getAttributes().contains(AttributeFlag.RETURN) - && block.getPredecessors().size() > 1 - && !block.getInstructions().get(0).getAttributes().contains(AttributeType.CATCH_BLOCK)) { + } + // splice return block if several predecessors presents + for (BlockNode block : mth.getExitBlocks()) { + if (block.getInstructions().size() == 1 + && block.getInstructions().get(0).getArgsCount() > 0 + && !block.getInstructions().get(0).getAttributes().contains(AttributeType.CATCH_BLOCK) + && !block.getAttributes().contains(AttributeFlag.SYNTHETIC)) { List preds = new ArrayList(block.getPredecessors()); - - BlockNode origRetBlock = block; - origRetBlock.getPredecessors().clear(); - origRetBlock.getPredecessors().add(preds.get(0)); - preds.remove(0); - - InsnNode origReturnInsn = origRetBlock.getInstructions().get(0); + InsnNode origReturnInsn = block.getInstructions().get(0); RegisterArg retArg = null; if (origReturnInsn.getArgsCount() != 0) retArg = (RegisterArg) origReturnInsn.getArg(0); for (BlockNode pred : preds) { - pred.getSuccessors().remove(origRetBlock); - // make copy of return block and connect to predecessor - BlockNode newRetBlock = startNewBlock(mth, origRetBlock.getStartOffset()); + BlockNode newRetBlock; + InsnNode predInsn = pred.getInstructions().get(0); + + switch (predInsn.getType()) { + case IF: + // make copy of return block and connect to predecessor + newRetBlock = startNewBlock(mth, block.getStartOffset()); + newRetBlock.getAttributes().add(AttributeFlag.SYNTHETIC); + + if (pred.getSuccessors().get(0) == block) { + pred.getSuccessors().set(0, newRetBlock); + } else if (pred.getSuccessors().get(1) == block){ + pred.getSuccessors().set(1, newRetBlock); + } + block.getPredecessors().remove(pred); + newRetBlock.getPredecessors().add(pred); + break; + + case SWITCH: + // TODO: is it ok to just skip this predecessor? + block.getAttributes().add(AttributeFlag.SYNTHETIC); + continue; + + default: + removeConnection(pred, block); + newRetBlock = pred; + break; + } InsnNode ret = new InsnNode(InsnType.RETURN, 1); - if (retArg != null) + if (retArg != null) { ret.addArg(InsnArg.reg(retArg.getRegNum(), retArg.getType())); + ret.getArg(0).forceSetTypedVar(retArg.getTypedVar()); + } ret.getAttributes().addAll(origReturnInsn.getAttributes()); newRetBlock.getInstructions().add(ret); newRetBlock.getAttributes().add(AttributeFlag.RETURN); - connect(pred, newRetBlock); mth.addExitBlock(newRetBlock); } - return true; + if (block.getPredecessors().size() == 0) { + mth.getBasicBlocks().remove(block); + mth.getExitBlocks().remove(block); + return true; + } + return block.getAttributes().contains(AttributeFlag.SYNTHETIC); } - - // TODO detect ternary operator } + // TODO detect ternary operator return false; } diff --git a/jadx-core/src/main/java/jadx/core/dex/visitors/CodeShrinker.java b/jadx-core/src/main/java/jadx/core/dex/visitors/CodeShrinker.java index f58a8ece6..7340668c1 100644 --- a/jadx-core/src/main/java/jadx/core/dex/visitors/CodeShrinker.java +++ b/jadx-core/src/main/java/jadx/core/dex/visitors/CodeShrinker.java @@ -81,13 +81,12 @@ public class CodeShrinker extends AbstractVisitor { } } if (wrap) { -// if (useInsn.getType() == InsnType.MOVE) { -// // TODO -// // remover.add(useInsn); -// } else { - useInsnArg.wrapInstruction(insn); + if (insn.getType() == InsnType.MOVE) { + useInsnArg.getParentInsn().setArg(0, insn.getArg(0)); + } else { + useInsnArg.wrapInstruction(insn); + } remover.add(insn); -// } } } } diff --git a/jadx-core/src/main/java/jadx/core/dex/visitors/ModVisitor.java b/jadx-core/src/main/java/jadx/core/dex/visitors/ModVisitor.java index fd5602dad..8cfea235f 100644 --- a/jadx-core/src/main/java/jadx/core/dex/visitors/ModVisitor.java +++ b/jadx-core/src/main/java/jadx/core/dex/visitors/ModVisitor.java @@ -3,23 +3,10 @@ package jadx.core.dex.visitors; import jadx.core.deobf.NameMapper; import jadx.core.dex.attributes.AttributeType; import jadx.core.dex.info.MethodInfo; -import jadx.core.dex.instructions.ConstClassNode; -import jadx.core.dex.instructions.ConstStringNode; -import jadx.core.dex.instructions.FillArrayNode; -import jadx.core.dex.instructions.IndexInsnNode; -import jadx.core.dex.instructions.InsnType; -import jadx.core.dex.instructions.InvokeNode; -import jadx.core.dex.instructions.SwitchNode; -import jadx.core.dex.instructions.args.ArgType; -import jadx.core.dex.instructions.args.InsnArg; -import jadx.core.dex.instructions.args.LiteralArg; -import jadx.core.dex.instructions.args.RegisterArg; +import jadx.core.dex.instructions.*; +import jadx.core.dex.instructions.args.*; import jadx.core.dex.instructions.mods.ConstructorInsn; -import jadx.core.dex.nodes.BlockNode; -import jadx.core.dex.nodes.ClassNode; -import jadx.core.dex.nodes.FieldNode; -import jadx.core.dex.nodes.InsnNode; -import jadx.core.dex.nodes.MethodNode; +import jadx.core.dex.nodes.*; import jadx.core.dex.trycatch.ExcHandlerAttr; import jadx.core.dex.trycatch.ExceptionHandler; import jadx.core.utils.BlockUtils; @@ -60,6 +47,8 @@ public class ModVisitor extends AbstractVisitor { int size = block.getInstructions().size(); for (int i = 0; i < size; i++) { InsnNode insn = block.getInstructions().get(i); + ClassNode parentClass = mth.getParentClass(); + FieldNode f = null; switch (insn.getType()) { case INVOKE: @@ -103,8 +92,6 @@ public class ModVisitor extends AbstractVisitor { case CONST: case CONST_STR: case CONST_CLASS: - ClassNode parentClass = mth.getParentClass(); - FieldNode f = null; if (insn.getType() == InsnType.CONST_STR) { String s = ((ConstStringNode) insn).getString(); f = parentClass.getConstField(s); @@ -112,19 +99,7 @@ public class ModVisitor extends AbstractVisitor { ArgType t = ((ConstClassNode) insn).getClsType(); f = parentClass.getConstField(t); } else { - LiteralArg arg = (LiteralArg) insn.getArg(0); - ArgType type = arg.getType(); - long lit = arg.getLiteral(); - if (type.equals(ArgType.DOUBLE)) - f = parentClass.getConstField(Double.longBitsToDouble(lit)); - else if (type.equals(ArgType.FLOAT)) - f = parentClass.getConstField(Float.intBitsToFloat((int) lit)); - else if (Math.abs(lit) > 0x1) { - if (type.equals(ArgType.INT)) - f = parentClass.getConstField((int) lit); - else if (type.equals(ArgType.LONG)) - f = parentClass.getConstField(lit); - } + f = parentClass.getConstFieldByLiteralArg((LiteralArg) insn.getArg(0)); } if (f != null) { InsnNode inode = new IndexInsnNode(InsnType.SGET, f.getFieldInfo(), 0); @@ -135,17 +110,25 @@ public class ModVisitor extends AbstractVisitor { case SWITCH: SwitchNode sn = (SwitchNode) insn; - parentClass = mth.getParentClass(); - f = null; for (int k = 0; k < sn.getCasesCount(); k++) { - f = parentClass.getConstField((Integer) sn.getKeys()[k]); + f = parentClass.getConstField(sn.getKeys()[k]); if (f != null) { - InsnNode inode = new IndexInsnNode(InsnType.SGET, f.getFieldInfo(), 0); - sn.getKeys()[k] = inode; + sn.getKeys()[k] = new IndexInsnNode(InsnType.SGET, f.getFieldInfo(), 0); } } break; - + + case RETURN: + if (insn.getArgsCount() > 0 + && insn.getArg(0).isLiteral()) { + LiteralArg arg = (LiteralArg) insn.getArg(0); + f = parentClass.getConstFieldByLiteralArg(arg); + if (f != null) { + arg.wrapInstruction(new IndexInsnNode(InsnType.SGET, f.getFieldInfo(), 0)); + } + } + break; + default: break; } diff --git a/jadx-core/src/test/java/jadx/tests/internal/TestReturnWrapping.java b/jadx-core/src/test/java/jadx/tests/internal/TestReturnWrapping.java new file mode 100644 index 000000000..ec04fb864 --- /dev/null +++ b/jadx-core/src/test/java/jadx/tests/internal/TestReturnWrapping.java @@ -0,0 +1,63 @@ +package jadx.tests.internal; + +import jadx.api.InternalJadxTest; +import jadx.core.dex.nodes.ClassNode; + +import org.junit.Test; + +import static org.hamcrest.CoreMatchers.containsString; +import static org.junit.Assert.assertThat; + +public class TestReturnWrapping extends InternalJadxTest { + public static class TestCls { + /**/ + public static int f1(int arg0) { + switch (arg0) { + case 1: + return 255; + } + return arg0 + 1; + }/**/ + + /**/ + public static Object f2(Object arg0, int arg1) { + Object ret = null; + int i = arg1; + if (arg0 == null) { + return ret + Integer.toHexString(i); + } else { + i++; + try { + ret = new Object().getClass(); + } catch (Exception e) { + ret = "Qwerty"; + } + return i > 128 ? arg0.toString() + ret.toString() : i; + } + }/**/ + + /**/ + public static int f3(int arg0) { + while (arg0 > 10) { + int abc = 951; + if (arg0 == 255) { + return arg0 + 2; + } + arg0 -= abc; + } + return arg0; + }/**/ + } + + @Test + public void test() { + ClassNode cls = getClassNode(TestCls.class); + String code = cls.getCode().toString(); + assertThat(code, containsString("return 255;")); + assertThat(code, containsString("return arg0 + 1;")); + //assertThat(code, containsString("return Integer.toHexString(i);")); + assertThat(code, containsString("return arg0.toString() + ret.toString();")); + assertThat(code, containsString("return arg0 + 2;")); + assertThat(code, containsString("arg0 -= 951;")); + } +} diff --git a/jadx-core/src/test/java/jadx/tests/internal/TestSwitchLabels.java b/jadx-core/src/test/java/jadx/tests/internal/TestSwitchLabels.java new file mode 100644 index 000000000..e88df5539 --- /dev/null +++ b/jadx-core/src/test/java/jadx/tests/internal/TestSwitchLabels.java @@ -0,0 +1,49 @@ +package jadx.tests.internal; + +import jadx.api.InternalJadxTest; +import jadx.core.dex.nodes.ClassNode; + +import org.junit.Test; + +import static org.hamcrest.CoreMatchers.containsString; +import static org.hamcrest.CoreMatchers.not; +import static org.junit.Assert.assertThat; +import static org.mockito.AdditionalMatchers.or; + +public class TestSwitchLabels extends InternalJadxTest { + public static class TestCls { + public static final int CONST_ABC = 0xABC; + public static final int CONST_CDE = 0xCDE; + + public static class Inner { + private static final int CONST_CDE_PRIVATE = 0xCDE; + public int f1(int arg0) { + switch (arg0) { + case CONST_CDE_PRIVATE: + return CONST_ABC; + } + return 0; + } + } + + public static int f1(int arg0) { + switch (arg0) { + case CONST_ABC: + return CONST_CDE; + } + return 0; + } + } + + @Test + public void test() { + ClassNode cls = getClassNode(TestCls.class); + String code = cls.getCode().toString(); + assertThat(code, containsString("case CONST_ABC:")); + assertThat(code, containsString("return CONST_CDE;")); + + cls.addInnerClass(getClassNode(TestCls.Inner.class)); + assertThat(code, containsString("case CONST_CDE_PRIVATE:")); + assertThat(code, containsString(".CONST_ABC;")); + } +}