From 249801880c6ff8877fbbd422ef225f222c673dcb Mon Sep 17 00:00:00 2001 From: Skylot <118523+skylot@users.noreply.github.com> Date: Thu, 17 Oct 2024 19:20:07 +0100 Subject: [PATCH] feat(api): allow to get method code (#2305) --- .../src/main/java/jadx/api/JavaMethod.java | 17 ++--- .../main/java/jadx/api/utils/CodeUtils.java | 75 +++++++++++++++++++ .../java/jadx/core/dex/nodes/MethodNode.java | 9 +++ .../api/utils/assertj/JadxAssertions.java | 6 ++ .../assertj/JadxMethodNodeAssertions.java | 20 +++++ .../jadx/tests/external/BaseExternalTest.java | 53 +------------ .../integration/others/TestConstReplace.java | 6 +- 7 files changed, 124 insertions(+), 62 deletions(-) create mode 100644 jadx-core/src/test/java/jadx/tests/api/utils/assertj/JadxMethodNodeAssertions.java diff --git a/jadx-core/src/main/java/jadx/api/JavaMethod.java b/jadx-core/src/main/java/jadx/api/JavaMethod.java index 22d4a5460..bfc5f96fa 100644 --- a/jadx-core/src/main/java/jadx/api/JavaMethod.java +++ b/jadx-core/src/main/java/jadx/api/JavaMethod.java @@ -2,7 +2,6 @@ package jadx.api; import java.util.Collections; import java.util.List; -import java.util.Objects; import java.util.stream.Collectors; import org.jetbrains.annotations.ApiStatus; @@ -79,15 +78,9 @@ public final class JavaMethod implements JavaNode { return Collections.emptyList(); } JadxDecompiler decompiler = getDeclaringClass().getRootDecompiler(); - return ovrdAttr.getRelatedMthNodes().stream() - .map(m -> { - JavaMethod javaMth = decompiler.convertMethodNode(m); - if (javaMth == null) { - LOG.warn("Failed convert to java method: {}", m); - } - return javaMth; - }) - .filter(Objects::nonNull) + return ovrdAttr.getRelatedMthNodes() + .stream() + .map(decompiler::convertMethodNode) .collect(Collectors.toList()); } @@ -104,6 +97,10 @@ public final class JavaMethod implements JavaNode { return mth.getDefPosition(); } + public String getCodeStr() { + return mth.getCodeStr(); + } + @Override public void removeAlias() { this.mth.getMethodInfo().removeAlias(); diff --git a/jadx-core/src/main/java/jadx/api/utils/CodeUtils.java b/jadx-core/src/main/java/jadx/api/utils/CodeUtils.java index 9b7730303..7cd9d3408 100644 --- a/jadx-core/src/main/java/jadx/api/utils/CodeUtils.java +++ b/jadx-core/src/main/java/jadx/api/utils/CodeUtils.java @@ -1,5 +1,13 @@ package jadx.api.utils; +import java.util.function.BiFunction; + +import jadx.api.ICodeInfo; +import jadx.api.metadata.ICodeAnnotation; +import jadx.api.metadata.ICodeNodeRef; +import jadx.api.metadata.annotations.NodeDeclareRef; +import jadx.core.dex.nodes.MethodNode; + public class CodeUtils { public static String getLineForPos(String code, int pos) { @@ -47,4 +55,71 @@ public class CodeUtils { line++; } } + + /** + * Cut method code (including comments and annotations) from class code. + * + * @return method code or empty string if metadata is not available + */ + public static String extractMethodCode(MethodNode mth, ICodeInfo codeInfo) { + int end = getMethodEnd(mth, codeInfo); + if (end == -1) { + return ""; + } + int start = getMethodStart(mth, codeInfo); + if (end < start) { + return ""; + } + return codeInfo.getCodeStr().substring(start, end); + } + + /** + * Search first empty line before method definition to include comments and annotations + */ + private static int getMethodStart(MethodNode mth, ICodeInfo codeInfo) { + int pos = mth.getDefPosition(); + String newLineStr = mth.root().getArgs().getCodeNewLineStr(); + String emptyLine = newLineStr + newLineStr; + int emptyLinePos = codeInfo.getCodeStr().lastIndexOf(emptyLine, pos); + return emptyLinePos == -1 ? pos : emptyLinePos + emptyLine.length(); + } + + /** + * Search method end position in provided class code info. + * + * @return end pos or -1 if metadata not available + */ + public static int getMethodEnd(MethodNode mth, ICodeInfo codeInfo) { + if (!codeInfo.hasMetadata()) { + return -1; + } + // skip nested nodes DEF/END until first unpaired END annotation (end of this method) + Integer end = codeInfo.getCodeMetadata().searchDown(mth.getDefPosition() + 1, new BiFunction<>() { + int nested = 0; + + @Override + public Integer apply(Integer pos, ICodeAnnotation ann) { + switch (ann.getAnnType()) { + case DECLARATION: + ICodeNodeRef node = ((NodeDeclareRef) ann).getNode(); + switch (node.getAnnType()) { + case CLASS: + case METHOD: + nested++; + break; + } + break; + + case END: + if (nested == 0) { + return pos; + } + nested--; + break; + } + return null; + } + }); + return end == null ? -1 : end; + } } diff --git a/jadx-core/src/main/java/jadx/core/dex/nodes/MethodNode.java b/jadx-core/src/main/java/jadx/core/dex/nodes/MethodNode.java index 277f6b695..1369af832 100644 --- a/jadx-core/src/main/java/jadx/core/dex/nodes/MethodNode.java +++ b/jadx-core/src/main/java/jadx/core/dex/nodes/MethodNode.java @@ -5,6 +5,7 @@ import java.util.Collections; import java.util.List; import java.util.Objects; +import org.jetbrains.annotations.ApiStatus; import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.Nullable; import org.slf4j.Logger; @@ -668,6 +669,13 @@ public class MethodNode extends NotificationAttrNode implements IMethodDetails, return insnsCount; } + /** + * Returns method code with comments and annotations + */ + public String getCodeStr() { + return CodeUtils.extractMethodCode(this, getTopParentClass().getCode()); + } + @Override public boolean isVarArg() { return accFlags.isVarArgs(); @@ -693,6 +701,7 @@ public class MethodNode extends NotificationAttrNode implements IMethodDetails, return javaNode; } + @ApiStatus.Internal public void setJavaNode(JavaMethod javaNode) { this.javaNode = javaNode; } diff --git a/jadx-core/src/test/java/jadx/tests/api/utils/assertj/JadxAssertions.java b/jadx-core/src/test/java/jadx/tests/api/utils/assertj/JadxAssertions.java index 163f1758c..7c4ab8ee0 100644 --- a/jadx-core/src/test/java/jadx/tests/api/utils/assertj/JadxAssertions.java +++ b/jadx-core/src/test/java/jadx/tests/api/utils/assertj/JadxAssertions.java @@ -4,6 +4,7 @@ import org.assertj.core.api.Assertions; import jadx.api.ICodeInfo; import jadx.core.dex.nodes.ClassNode; +import jadx.core.dex.nodes.MethodNode; public class JadxAssertions extends Assertions { @@ -12,6 +13,11 @@ public class JadxAssertions extends Assertions { return new JadxClassNodeAssertions(cls); } + public static JadxMethodNodeAssertions assertThat(MethodNode mth) { + Assertions.assertThat(mth).isNotNull(); + return new JadxMethodNodeAssertions(mth); + } + public static JadxCodeInfoAssertions assertThat(ICodeInfo codeInfo) { Assertions.assertThat(codeInfo).isNotNull(); return new JadxCodeInfoAssertions(codeInfo); diff --git a/jadx-core/src/test/java/jadx/tests/api/utils/assertj/JadxMethodNodeAssertions.java b/jadx-core/src/test/java/jadx/tests/api/utils/assertj/JadxMethodNodeAssertions.java new file mode 100644 index 000000000..2d0125e47 --- /dev/null +++ b/jadx-core/src/test/java/jadx/tests/api/utils/assertj/JadxMethodNodeAssertions.java @@ -0,0 +1,20 @@ +package jadx.tests.api.utils.assertj; + +import org.assertj.core.api.AbstractObjectAssert; + +import jadx.core.dex.nodes.MethodNode; + +import static org.assertj.core.api.Assertions.assertThat; + +public class JadxMethodNodeAssertions extends AbstractObjectAssert { + public JadxMethodNodeAssertions(MethodNode mth) { + super(mth, JadxMethodNodeAssertions.class); + } + + public JadxCodeAssertions code() { + isNotNull(); + String codeStr = actual.getCodeStr(); + assertThat(codeStr).isNotBlank(); + return new JadxCodeAssertions(codeStr); + } +} diff --git a/jadx-core/src/test/java/jadx/tests/external/BaseExternalTest.java b/jadx-core/src/test/java/jadx/tests/external/BaseExternalTest.java index e1095406e..6294a63fd 100644 --- a/jadx-core/src/test/java/jadx/tests/external/BaseExternalTest.java +++ b/jadx-core/src/test/java/jadx/tests/external/BaseExternalTest.java @@ -1,7 +1,6 @@ package jadx.tests.external; import java.io.File; -import java.util.function.BiFunction; import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.Nullable; @@ -13,9 +12,6 @@ import jadx.api.ICodeInfo; import jadx.api.JadxArgs; import jadx.api.JadxDecompiler; import jadx.api.JadxInternalAccess; -import jadx.api.metadata.ICodeAnnotation; -import jadx.api.metadata.ICodeNodeRef; -import jadx.api.metadata.annotations.NodeDeclareRef; import jadx.core.dex.nodes.ClassNode; import jadx.core.dex.nodes.MethodNode; import jadx.core.dex.nodes.RootNode; @@ -134,58 +130,15 @@ public abstract class BaseExternalTest extends TestUtils { String dashLine = "======================================================================================"; for (MethodNode mth : classNode.getMethods()) { if (isMthMatch(mth, mthPattern)) { - String mthCode = cutMethodCode(codeInfo, mth); - LOG.info("Print method: {}\n{}\n{}\n{}", mth.getMethodInfo().getShortId(), + LOG.info("Print method: {}\n{}\n{}\n{}", + mth.getMethodInfo().getShortId(), dashLine, - mthCode, + mth.getCodeStr(), dashLine); } } } - private String cutMethodCode(ICodeInfo codeInfo, MethodNode mth) { - int startPos = getCommentStartPos(codeInfo, mth.getDefPosition()); - int stopPos = getMethodEnd(mth, codeInfo); - return codeInfo.getCodeStr().substring(startPos, stopPos); - } - - private int getMethodEnd(MethodNode mth, ICodeInfo codeInfo) { - // skip nested nodes DEF/END until first unpaired END annotation (end of this method) - Integer end = codeInfo.getCodeMetadata().searchDown(mth.getDefPosition() + 1, new BiFunction<>() { - int nested = 0; - - @Override - public Integer apply(Integer pos, ICodeAnnotation ann) { - switch (ann.getAnnType()) { - case DECLARATION: - ICodeNodeRef node = ((NodeDeclareRef) ann).getNode(); - switch (node.getAnnType()) { - case CLASS: - case METHOD: - nested++; - break; - } - break; - - case END: - if (nested == 0) { - return pos; - } - nested--; - break; - } - return null; - } - }); - return end != null ? end : codeInfo.getCodeStr().length(); - } - - protected int getCommentStartPos(ICodeInfo codeInfo, int pos) { - String emptyLine = "\n\n"; - int emptyLinePos = codeInfo.getCodeStr().lastIndexOf(emptyLine, pos); - return emptyLinePos == -1 ? pos : emptyLinePos + emptyLine.length(); - } - private void printErrorReport(JadxDecompiler jadx) { jadx.printErrorsReport(); assertThat(jadx.getErrorsCount()).isEqualTo(0); diff --git a/jadx-core/src/test/java/jadx/tests/integration/others/TestConstReplace.java b/jadx-core/src/test/java/jadx/tests/integration/others/TestConstReplace.java index a44dcdf35..21bfd81b2 100644 --- a/jadx-core/src/test/java/jadx/tests/integration/others/TestConstReplace.java +++ b/jadx-core/src/test/java/jadx/tests/integration/others/TestConstReplace.java @@ -22,9 +22,11 @@ public class TestConstReplace extends IntegrationTest { @Test public void test() { ClassNode cls = getClassNode(TestCls.class); - assertThat(cls).code().containsOne("return CONST_VALUE;"); MethodNode testMth = cls.searchMethodByShortName("test"); - assertThat(testMth).isNotNull(); + assertThat(testMth) + .code() + .print() + .containsOne("return CONST_VALUE;"); FieldNode constField = cls.searchFieldByName("CONST_VALUE"); assertThat(constField).isNotNull();