From 22fa1321106ca412e54ba26325d07e287f322647 Mon Sep 17 00:00:00 2001 From: Skylot Date: Mon, 1 Feb 2021 18:37:13 +0000 Subject: [PATCH] fix: support instance invoke for 'invoke-custom' instruction (#384) --- .../main/java/jadx/core/codegen/InsnGen.java | 57 ++++++-- .../dex/instructions/InvokeCustomBuilder.java | 134 +++++++++++------- .../dex/instructions/InvokeCustomNode.java | 9 ++ .../integration/java8/TestLambdaInstance.java | 70 +++++++++ .../integration/java8/TestLambdaStatic.java | 3 +- 5 files changed, 210 insertions(+), 63 deletions(-) create mode 100644 jadx-core/src/test/java/jadx/tests/integration/java8/TestLambdaInstance.java diff --git a/jadx-core/src/main/java/jadx/core/codegen/InsnGen.java b/jadx-core/src/main/java/jadx/core/codegen/InsnGen.java index eff0acd2d..27c0572bd 100644 --- a/jadx-core/src/main/java/jadx/core/codegen/InsnGen.java +++ b/jadx-core/src/main/java/jadx/core/codegen/InsnGen.java @@ -9,6 +9,7 @@ import org.jetbrains.annotations.Nullable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import jadx.api.plugins.input.data.MethodHandleType; import jadx.core.deobf.NameMapper; import jadx.core.dex.attributes.AFlag; import jadx.core.dex.attributes.AType; @@ -35,16 +36,29 @@ import jadx.core.dex.instructions.InvokeNode; import jadx.core.dex.instructions.InvokeType; import jadx.core.dex.instructions.NewArrayNode; import jadx.core.dex.instructions.SwitchInsn; -import jadx.core.dex.instructions.args.*; +import jadx.core.dex.instructions.args.ArgType; +import jadx.core.dex.instructions.args.CodeVar; +import jadx.core.dex.instructions.args.InsnArg; +import jadx.core.dex.instructions.args.InsnWrapArg; +import jadx.core.dex.instructions.args.LiteralArg; +import jadx.core.dex.instructions.args.Named; +import jadx.core.dex.instructions.args.NamedArg; +import jadx.core.dex.instructions.args.RegisterArg; +import jadx.core.dex.instructions.args.SSAVar; import jadx.core.dex.instructions.mods.ConstructorInsn; import jadx.core.dex.instructions.mods.TernaryInsn; -import jadx.core.dex.nodes.*; +import jadx.core.dex.nodes.ClassNode; +import jadx.core.dex.nodes.FieldNode; +import jadx.core.dex.nodes.InsnNode; +import jadx.core.dex.nodes.MethodNode; +import jadx.core.dex.nodes.RootNode; +import jadx.core.dex.nodes.VariableNode; import jadx.core.utils.CodeGenUtils; import jadx.core.utils.RegionUtils; import jadx.core.utils.exceptions.CodegenException; import jadx.core.utils.exceptions.JadxRuntimeException; -import static jadx.core.dex.nodes.VariableNode.*; +import static jadx.core.dex.nodes.VariableNode.VarKind; import static jadx.core.utils.android.AndroidResourcesUtils.handleAppResField; public class InsnGen { @@ -765,6 +779,10 @@ public class InsnGen { } private void makeInvokeLambda(CodeWriter code, InvokeCustomNode customNode) throws CodegenException { + if (customNode.isUseRef()) { + makeRefLambda(code, customNode); + return; + } if (fallback || !customNode.isInlineInsn()) { makeSimpleLambda(code, customNode); return; @@ -773,6 +791,17 @@ public class InsnGen { makeInlinedLambdaMethod(code, customNode, callMth); } + private void makeRefLambda(CodeWriter code, InvokeCustomNode customNode) { + InvokeNode invokeInsn = (InvokeNode) customNode.getCallInsn(); + MethodInfo callMth = invokeInsn.getCallMth(); + if (customNode.getHandleType() == MethodHandleType.INVOKE_STATIC) { + useClass(code, callMth.getDeclClass()); + } else { + code.add("this"); + } + code.add("::").add(callMth.getAlias()); + } + private void makeSimpleLambda(CodeWriter code, InvokeCustomNode customNode) { try { InsnNode callInsn = customNode.getCallInsn(); @@ -782,17 +811,22 @@ public class InsnGen { code.add("()"); } else { code.add('('); - // rename lambda args int callArgsCount = callInsn.getArgsCount(); int startArg = callArgsCount - implArgsCount; - if (startArg < 0) { - System.out.println(); + if (customNode.getHandleType() != MethodHandleType.INVOKE_STATIC + && customNode.getArgsCount() > 0 + && customNode.getArg(0).isThis()) { + callInsn.getArg(0).add(AFlag.THIS); } - for (int i = startArg; i < callArgsCount; i++) { - if (i != startArg) { - code.add(", "); + if (startArg >= 0) { + for (int i = startArg; i < callArgsCount; i++) { + if (i != startArg) { + code.add(", "); + } + addArg(code, callInsn.getArg(i)); } - addArg(code, callInsn.getArg(i)); + } else { + code.add("/* ERROR: " + startArg + " */"); } code.add(')'); } @@ -837,7 +871,8 @@ public class InsnGen { } // force set external arg names into call method args int extArgsCount = customNode.getArgsCount(); - for (int i = 0; i < extArgsCount; i++) { + int startArg = customNode.getHandleType() == MethodHandleType.INVOKE_STATIC ? 0 : 1; // skip 'this' arg + for (int i = startArg; i < extArgsCount; i++) { RegisterArg extArg = (RegisterArg) customNode.getArg(i); callArgs.get(i).setName(extArg.getName()); } diff --git a/jadx-core/src/main/java/jadx/core/dex/instructions/InvokeCustomBuilder.java b/jadx-core/src/main/java/jadx/core/dex/instructions/InvokeCustomBuilder.java index fd1bca23b..e7e0d61a6 100644 --- a/jadx-core/src/main/java/jadx/core/dex/instructions/InvokeCustomBuilder.java +++ b/jadx-core/src/main/java/jadx/core/dex/instructions/InvokeCustomBuilder.java @@ -2,6 +2,8 @@ package jadx.core.dex.instructions; import java.util.List; +import org.jetbrains.annotations.NotNull; + import jadx.api.plugins.input.data.ICallSite; import jadx.api.plugins.input.data.IMethodHandle; import jadx.api.plugins.input.data.IMethodProto; @@ -18,6 +20,7 @@ import jadx.core.dex.instructions.args.NamedArg; import jadx.core.dex.nodes.InsnNode; import jadx.core.dex.nodes.MethodNode; import jadx.core.dex.nodes.RootNode; +import jadx.core.utils.Utils; import jadx.core.utils.exceptions.JadxRuntimeException; public class InvokeCustomBuilder { @@ -31,64 +34,93 @@ public class InvokeCustomBuilder { throw new JadxRuntimeException("Failed to process invoke-custom instruction: " + callSite); } IMethodHandle callMthHandle = (IMethodHandle) values.get(4).getValue(); - MethodHandleType methodHandleType = callMthHandle.getType(); - if (methodHandleType.isField()) { + if (callMthHandle.getType().isField()) { throw new JadxRuntimeException("Not yet supported"); } - RootNode root = mth.root(); - IMethodProto lambdaProto = (IMethodProto) values.get(2).getValue(); - MethodInfo lambdaInfo = MethodInfo.fromMethodProto(root, mth.getParentClass().getClassInfo(), "", lambdaProto); - - InvokeCustomNode invokeCustomNode = new InvokeCustomNode(lambdaInfo, insn, false, isRange); - invokeCustomNode.setHandleType(methodHandleType); - - ClassInfo implCls = ClassInfo.fromType(root, lambdaInfo.getReturnType()); - String implName = (String) values.get(1).getValue(); - IMethodProto implProto = (IMethodProto) values.get(3).getValue(); - invokeCustomNode.setImplMthInfo(MethodInfo.fromMethodProto(root, implCls, implName, implProto)); - - MethodInfo callMthInfo = MethodInfo.fromRef(root, callMthHandle.getMethodRef()); - - InvokeType invokeType = convertInvokeType(methodHandleType); - int callArgsCount = callMthInfo.getArgsCount(); - InvokeNode callInsn = new InvokeNode(callMthInfo, invokeType, callArgsCount); - invokeCustomNode.setCallInsn(callInsn); - - // copy insn args - int argsCount = invokeCustomNode.getArgsCount(); - for (int i = 0; i < argsCount; i++) { - InsnArg arg = invokeCustomNode.getArg(i); - callInsn.addArg(arg.duplicate()); - } - if (callArgsCount > argsCount) { - // fill remaining args with NamedArg - for (int i = argsCount; i < callArgsCount; i++) { - ArgType argType = callMthInfo.getArgumentsTypes().get(i); - callInsn.addArg(new NamedArg("v" + i, argType)); - } - } - - MethodNode callMth = root.resolveMethod(callMthInfo); - if (callMth != null) { - callInsn.addAttr(callMth); - if (callMth.getAccessFlags().isSynthetic() - && callMth.getUseIn().size() <= 1 - && callMth.getParentClass().equals(mth.getParentClass())) { - // inline only synthetic methods from same class - callMth.add(AFlag.DONT_GENERATE); - invokeCustomNode.setInlineInsn(true); - } - } - // prevent args inlining into not generated invoke custom node - for (InsnArg arg : invokeCustomNode.getArguments()) { - arg.add(AFlag.DONT_INLINE); - } - return invokeCustomNode; + return buildMethodCall(mth, insn, isRange, values, callMthHandle); } catch (Exception e) { throw new JadxRuntimeException("'invoke-custom' instruction processing error: " + e.getMessage(), e); } } + @NotNull + private static InvokeCustomNode buildMethodCall(MethodNode mth, InsnData insn, boolean isRange, + List values, IMethodHandle callMthHandle) { + RootNode root = mth.root(); + IMethodProto lambdaProto = (IMethodProto) values.get(2).getValue(); + MethodInfo lambdaInfo = MethodInfo.fromMethodProto(root, mth.getParentClass().getClassInfo(), "", lambdaProto); + + MethodHandleType methodHandleType = callMthHandle.getType(); + InvokeCustomNode invokeCustomNode = new InvokeCustomNode(lambdaInfo, insn, false, isRange); + invokeCustomNode.setHandleType(methodHandleType); + + ClassInfo implCls = ClassInfo.fromType(root, lambdaInfo.getReturnType()); + String implName = (String) values.get(1).getValue(); + IMethodProto implProto = (IMethodProto) values.get(3).getValue(); + MethodInfo implMthInfo = MethodInfo.fromMethodProto(root, implCls, implName, implProto); + invokeCustomNode.setImplMthInfo(implMthInfo); + + MethodInfo callMthInfo = MethodInfo.fromRef(root, callMthHandle.getMethodRef()); + + InvokeType invokeType = convertInvokeType(methodHandleType); + int callArgsCount = callMthInfo.getArgsCount(); + boolean instanceCall = invokeType != InvokeType.STATIC; + if (instanceCall) { + callArgsCount++; + } + InvokeNode callInsn = new InvokeNode(callMthInfo, invokeType, callArgsCount); + invokeCustomNode.setCallInsn(callInsn); + + // copy insn args + int argsCount = invokeCustomNode.getArgsCount(); + for (int i = 0; i < argsCount; i++) { + InsnArg arg = invokeCustomNode.getArg(i); + callInsn.addArg(arg.duplicate()); + } + if (callArgsCount > argsCount) { + // fill remaining args with NamedArg + int callArgNum = argsCount; + if (instanceCall) { + callArgNum--; // start from instance type + } + List callArgTypes = callMthInfo.getArgumentsTypes(); + for (int i = argsCount; i < callArgsCount; i++) { + ArgType argType; + if (callArgNum < 0) { + // instance arg type + argType = callMthInfo.getDeclClass().getType(); + } else { + argType = callArgTypes.get(callArgNum++); + } + callInsn.addArg(new NamedArg("v" + i, argType)); + } + } + + MethodNode callMth = root.resolveMethod(callMthInfo); + if (callMth != null) { + callInsn.addAttr(callMth); + if (callMth.getAccessFlags().isSynthetic() + && callMth.getUseIn().size() <= 1 + && callMth.getParentClass().equals(mth.getParentClass())) { + // inline only synthetic methods from same class + callMth.add(AFlag.DONT_GENERATE); + invokeCustomNode.setInlineInsn(true); + } + } + if (!invokeCustomNode.isInlineInsn()) { + IMethodProto effectiveMthProto = (IMethodProto) values.get(5).getValue(); + List args = Utils.collectionMap(effectiveMthProto.getArgTypes(), ArgType::parse); + boolean sameArgs = args.equals(callMthInfo.getArgumentsTypes()); + invokeCustomNode.setUseRef(sameArgs); + } + + // prevent args inlining into not generated invoke custom node + for (InsnArg arg : invokeCustomNode.getArguments()) { + arg.add(AFlag.DONT_INLINE); + } + return invokeCustomNode; + } + /** * Expect LambdaMetafactory.metafactory method */ diff --git a/jadx-core/src/main/java/jadx/core/dex/instructions/InvokeCustomNode.java b/jadx-core/src/main/java/jadx/core/dex/instructions/InvokeCustomNode.java index 40ab67839..c308eed8e 100644 --- a/jadx-core/src/main/java/jadx/core/dex/instructions/InvokeCustomNode.java +++ b/jadx-core/src/main/java/jadx/core/dex/instructions/InvokeCustomNode.java @@ -14,6 +14,7 @@ public class InvokeCustomNode extends InvokeNode { private MethodHandleType handleType; private InsnNode callInsn; private boolean inlineInsn; + private boolean useRef; public InvokeCustomNode(MethodInfo lambdaInfo, InsnData insn, boolean instanceCall, boolean isRange) { super(lambdaInfo, insn, InvokeType.CUSTOM, instanceCall, isRange); @@ -51,6 +52,14 @@ public class InvokeCustomNode extends InvokeNode { this.inlineInsn = inlineInsn; } + public boolean isUseRef() { + return useRef; + } + + public void setUseRef(boolean useRef) { + this.useRef = useRef; + } + @Nullable public BaseInvokeNode getInvokeCall() { if (callInsn.getType() == InsnType.INVOKE) { diff --git a/jadx-core/src/test/java/jadx/tests/integration/java8/TestLambdaInstance.java b/jadx-core/src/test/java/jadx/tests/integration/java8/TestLambdaInstance.java new file mode 100644 index 000000000..c6ba4eb2e --- /dev/null +++ b/jadx-core/src/test/java/jadx/tests/integration/java8/TestLambdaInstance.java @@ -0,0 +1,70 @@ +package jadx.tests.integration.java8; + +import java.util.function.Function; + +import org.junit.jupiter.api.Test; + +import jadx.tests.api.IntegrationTest; + +import static jadx.tests.api.utils.assertj.JadxAssertions.assertThat; + +public class TestLambdaInstance extends IntegrationTest { + + @SuppressWarnings("Convert2MethodRef") + public static class TestCls { + + public Function test() { + return str -> this.call(str); + } + + public Function testMthRef() { + return this::call; + } + + public Integer call(String str) { + return Integer.parseInt(str); + } + + public Function test2() { + return num -> num.toString(); + } + + public Function testMthRef2() { + return Object::toString; + } + + public void check() throws Exception { + assertThat(test().apply("11")).isEqualTo(11); + assertThat(testMthRef().apply("7")).isEqualTo(7); + + assertThat(test2().apply(15)).isEqualTo("15"); + assertThat(testMthRef2().apply(13)).isEqualTo("13"); + } + } + + @Test + public void test() { + assertThat(getClassNode(TestCls.class)) + .code() + .doesNotContain("lambda$") + .doesNotContain("renamed") + .containsLines(2, + "return str -> {", + indent() + "return call(str);", + "};") + // .containsOne("return Object::toString;") // TODO + .containsOne("return this::call;"); + } + + @Test + public void testNoDebug() { + noDebugInfo(); + getClassNode(TestCls.class); + } + + @Test + public void testFallback() { + setFallback(); + getClassNode(TestCls.class); + } +} diff --git a/jadx-core/src/test/java/jadx/tests/integration/java8/TestLambdaStatic.java b/jadx-core/src/test/java/jadx/tests/integration/java8/TestLambdaStatic.java index e3e739301..257f07121 100644 --- a/jadx-core/src/test/java/jadx/tests/integration/java8/TestLambdaStatic.java +++ b/jadx-core/src/test/java/jadx/tests/integration/java8/TestLambdaStatic.java @@ -61,7 +61,8 @@ public class TestLambdaStatic extends IntegrationTest { .containsLines(2, "return () -> {", indent() + "return str;", - "};"); + "};") + .containsOne("return Integer::parseInt;"); } @Test