refactor: make input plugin api similar to pass plugins

This commit is contained in:
Skylot
2022-08-22 17:47:58 +01:00
parent 0c4d46ead5
commit a89dbc1152
38 changed files with 370 additions and 379 deletions
@@ -25,7 +25,7 @@ public class JadxArgsValidator {
private static void checkInputFiles(JadxDecompiler jadx, JadxArgs args) {
List<File> inputFiles = args.getInputFiles();
if (inputFiles.isEmpty() && jadx.getCustomLoads().isEmpty()) {
if (inputFiles.isEmpty() && jadx.getCustomCodeLoaders().isEmpty()) {
throw new JadxArgsValidateException("Please specify input file");
}
for (File inputFile : inputFiles) {
@@ -23,7 +23,7 @@ import org.jetbrains.annotations.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import jadx.api.impl.plugins.SimplePluginContext;
import jadx.api.impl.plugins.PluginsContext;
import jadx.api.metadata.ICodeAnnotation;
import jadx.api.metadata.ICodeNodeRef;
import jadx.api.metadata.annotations.NodeDeclareRef;
@@ -31,10 +31,8 @@ import jadx.api.metadata.annotations.VarNode;
import jadx.api.metadata.annotations.VarRef;
import jadx.api.plugins.JadxPlugin;
import jadx.api.plugins.JadxPluginManager;
import jadx.api.plugins.gui.JadxGuiContext;
import jadx.api.plugins.input.JadxInputPlugin;
import jadx.api.plugins.input.data.ILoadResult;
import jadx.api.plugins.options.JadxPluginOptions;
import jadx.api.plugins.input.ICodeLoader;
import jadx.api.plugins.input.JadxCodeInput;
import jadx.api.plugins.pass.JadxPass;
import jadx.api.plugins.pass.types.JadxAfterLoadPass;
import jadx.api.plugins.pass.types.JadxPassType;
@@ -88,7 +86,7 @@ public final class JadxDecompiler implements Closeable {
private final JadxArgs args;
private final JadxPluginManager pluginManager = new JadxPluginManager();
private final List<ILoadResult> loadedInputs = new ArrayList<>();
private final List<ICodeLoader> loadedInputs = new ArrayList<>();
private RootNode root;
private List<JavaClass> classes;
@@ -99,9 +97,9 @@ public final class JadxDecompiler implements Closeable {
private final IDecompileScheduler decompileScheduler = new DecompilerScheduler();
private final List<ILoadResult> customLoads = new ArrayList<>();
private final PluginsContext pluginsContext = new PluginsContext(this);
private final List<ICodeLoader> customCodeLoaders = new ArrayList<>();
private final Map<JadxPassType, List<JadxPass>> customPasses = new HashMap<>();
private @Nullable JadxGuiContext guiContext;
public JadxDecompiler() {
this(new JadxArgs());
@@ -135,13 +133,13 @@ public final class JadxDecompiler implements Closeable {
List<Path> inputPaths = Utils.collectionMap(args.getInputFiles(), File::toPath);
List<Path> inputFiles = FileUtils.expandDirs(inputPaths);
long start = System.currentTimeMillis();
for (JadxInputPlugin inputPlugin : pluginManager.getInputPlugins()) {
ILoadResult loadResult = inputPlugin.loadFiles(inputFiles);
if (loadResult != null && !loadResult.isEmpty()) {
loadedInputs.add(loadResult);
for (JadxCodeInput codeLoader : pluginsContext.getCodeInputs()) {
ICodeLoader loader = codeLoader.loadFiles(inputFiles);
if (loader != null && !loader.isEmpty()) {
loadedInputs.add(loader);
}
}
loadedInputs.addAll(customLoads);
loadedInputs.addAll(customCodeLoaders);
if (LOG.isDebugEnabled()) {
LOG.debug("Loaded using {} inputs plugin in {} ms", loadedInputs.size(), System.currentTimeMillis() - start);
}
@@ -180,39 +178,7 @@ public final class JadxDecompiler implements Closeable {
LOG.debug("Resolved plugins: {}", Utils.collectionMap(pluginManager.getResolvedPlugins(),
p -> p.getPluginInfo().getPluginId()));
}
applyPluginOptions();
initPlugins();
}
private void applyPluginOptions() {
Map<String, String> pluginOptions = args.getPluginOptions();
if (!pluginOptions.isEmpty()) {
LOG.debug("Applying plugin options: {}", pluginOptions);
for (JadxPluginOptions plugin : pluginManager.getPluginsWithOptions()) {
try {
plugin.setOptions(pluginOptions);
} catch (Exception e) {
String pluginId = plugin.getPluginInfo().getPluginId();
throw new JadxRuntimeException("Failed to apply options for plugin: " + pluginId, e);
}
}
}
}
private void initPlugins() {
customPasses.clear();
List<JadxPlugin> plugins = pluginManager.getResolvedPlugins();
SimplePluginContext context = new SimplePluginContext(this);
context.setGuiContext(guiContext);
for (JadxPlugin passPlugin : plugins) {
try {
passPlugin.init(context);
} catch (Exception e) {
String pluginId = passPlugin.getPluginInfo().getPluginId();
throw new JadxRuntimeException("Failed to pass plugin: " + pluginId, e);
}
}
pluginManager.initResolved(pluginsContext);
if (LOG.isDebugEnabled()) {
List<String> passes = customPasses.values().stream().flatMap(Collection::stream)
.map(p -> p.getInfo().getName()).collect(Collectors.toList());
@@ -684,20 +650,20 @@ public final class JadxDecompiler implements Closeable {
return decompileScheduler;
}
public void addCustomLoad(ILoadResult customLoad) {
customLoads.add(customLoad);
public void addCustomCodeLoader(ICodeLoader customCodeLoader) {
customCodeLoaders.add(customCodeLoader);
}
public List<ILoadResult> getCustomLoads() {
return customLoads;
public List<ICodeLoader> getCustomCodeLoaders() {
return customCodeLoaders;
}
public void addCustomPass(JadxPass pass) {
customPasses.computeIfAbsent(pass.getPassType(), l -> new ArrayList<>()).add(pass);
}
public void setJadxGuiContext(JadxGuiContext guiContext) {
this.guiContext = guiContext;
public PluginsContext getPluginsContext() {
return pluginsContext;
}
@Override
@@ -0,0 +1,86 @@
package jadx.api.impl.plugins;
import java.util.ArrayList;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.jetbrains.annotations.Nullable;
import jadx.api.JadxArgs;
import jadx.api.JadxDecompiler;
import jadx.api.plugins.JadxPlugin;
import jadx.api.plugins.JadxPluginContext;
import jadx.api.plugins.gui.JadxGuiContext;
import jadx.api.plugins.input.JadxCodeInput;
import jadx.api.plugins.options.JadxPluginOptions;
import jadx.api.plugins.pass.JadxPass;
import jadx.core.utils.exceptions.JadxRuntimeException;
public class PluginsContext implements JadxPluginContext {
private final JadxDecompiler decompiler;
private final List<JadxCodeInput> codeInputs = new ArrayList<>();
private final Map<JadxPlugin, JadxPluginOptions> optionsMap = new IdentityHashMap<>();
private @Nullable JadxGuiContext guiContext;
private @Nullable JadxPlugin currentPlugin;
public PluginsContext(JadxDecompiler decompiler) {
this.decompiler = decompiler;
}
@Override
public JadxArgs getArgs() {
return decompiler.getArgs();
}
@Override
public JadxDecompiler getDecompiler() {
return decompiler;
}
@Override
public void addPass(JadxPass pass) {
decompiler.addCustomPass(pass);
}
@Override
public void addCodeInput(JadxCodeInput codeInput) {
codeInputs.add(codeInput);
}
public List<JadxCodeInput> getCodeInputs() {
return codeInputs;
}
public void setCurrentPlugin(JadxPlugin currentPlugin) {
this.currentPlugin = currentPlugin;
}
@Override
public void registerOptions(JadxPluginOptions options) {
Objects.requireNonNull(currentPlugin);
try {
options.setOptions(decompiler.getArgs().getPluginOptions());
optionsMap.put(currentPlugin, options);
} catch (Exception e) {
String pluginId = currentPlugin.getPluginInfo().getPluginId();
throw new JadxRuntimeException("Failed to apply options for plugin: " + pluginId, e);
}
}
public Map<JadxPlugin, JadxPluginOptions> getOptionsMap() {
return optionsMap;
}
@Override
public @Nullable JadxGuiContext getGuiContext() {
return guiContext;
}
public void setGuiContext(JadxGuiContext guiContext) {
this.guiContext = guiContext;
}
}
@@ -1,19 +0,0 @@
package jadx.api.impl.plugins;
import jadx.api.JadxDecompiler;
import jadx.api.plugins.pass.JadxPass;
import jadx.api.plugins.pass.JadxPassContext;
public class SimplePassContext implements JadxPassContext {
private final JadxDecompiler jadxDecompiler;
public SimplePassContext(JadxDecompiler jadxDecompiler) {
this.jadxDecompiler = jadxDecompiler;
}
@Override
public void addPass(JadxPass pass) {
jadxDecompiler.addCustomPass(pass);
}
}
@@ -1,45 +0,0 @@
package jadx.api.impl.plugins;
import org.jetbrains.annotations.Nullable;
import jadx.api.JadxArgs;
import jadx.api.JadxDecompiler;
import jadx.api.plugins.JadxPluginContext;
import jadx.api.plugins.gui.JadxGuiContext;
import jadx.api.plugins.pass.JadxPassContext;
public class SimplePluginContext implements JadxPluginContext {
private final JadxDecompiler decompiler;
private final JadxPassContext passContext;
private @Nullable JadxGuiContext guiContext;
public SimplePluginContext(JadxDecompiler decompiler) {
this.decompiler = decompiler;
this.passContext = new SimplePassContext(decompiler);
}
@Override
public JadxArgs getArgs() {
return decompiler.getArgs();
}
@Override
public JadxDecompiler getDecompiler() {
return decompiler;
}
@Override
public JadxPassContext getPassContext() {
return passContext;
}
@Override
public @Nullable JadxGuiContext getGuiContext() {
return guiContext;
}
public void setGuiContext(JadxGuiContext guiContext) {
this.guiContext = guiContext;
}
}
@@ -1,9 +1,8 @@
package jadx.api.plugins;
public interface JadxPlugin {
JadxPluginInfo getPluginInfo();
default void init(JadxPluginContext context) {
// default to no-op
}
void init(JadxPluginContext context);
}
@@ -5,7 +5,9 @@ import org.jetbrains.annotations.Nullable;
import jadx.api.JadxArgs;
import jadx.api.JadxDecompiler;
import jadx.api.plugins.gui.JadxGuiContext;
import jadx.api.plugins.pass.JadxPassContext;
import jadx.api.plugins.input.JadxCodeInput;
import jadx.api.plugins.options.JadxPluginOptions;
import jadx.api.plugins.pass.JadxPass;
public interface JadxPluginContext {
@@ -13,7 +15,11 @@ public interface JadxPluginContext {
JadxDecompiler getDecompiler();
JadxPassContext getPassContext();
void addPass(JadxPass pass);
void addCodeInput(JadxCodeInput codeInput);
void registerOptions(JadxPluginOptions options);
@Nullable
JadxGuiContext getGuiContext();
@@ -15,9 +15,10 @@ import org.jetbrains.annotations.NotNull;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import jadx.api.plugins.input.JadxInputPlugin;
import jadx.api.impl.plugins.PluginsContext;
import jadx.api.plugins.options.JadxPluginOptions;
import jadx.api.plugins.options.OptionDescription;
import jadx.core.utils.exceptions.JadxRuntimeException;
public class JadxPluginManager {
private static final Logger LOG = LoggerFactory.getLogger(JadxPluginManager.class);
@@ -59,34 +60,9 @@ public class JadxPluginManager {
if (!allPlugins.add(pluginData)) {
throw new IllegalArgumentException("Duplicate plugin id: " + pluginData + ", class " + plugin.getClass());
}
if (plugin instanceof JadxPluginOptions) {
verifyOptions(((JadxPluginOptions) plugin), pluginData.getPluginId());
}
return pluginData;
}
private void verifyOptions(JadxPluginOptions plugin, String pluginId) {
List<OptionDescription> descriptions = plugin.getOptionsDescriptions();
if (descriptions == null) {
throw new IllegalArgumentException("Null option descriptions in plugin id: " + pluginId);
}
String prefix = pluginId + '.';
descriptions.forEach(descObj -> {
String optName = descObj.name();
if (optName == null || !optName.startsWith(prefix)) {
throw new IllegalArgumentException("Plugin option name should start with plugin id: '" + prefix + "', option: " + optName);
}
String desc = descObj.description();
if (desc == null || desc.isEmpty()) {
throw new IllegalArgumentException("Plugin option description not set, plugin: " + pluginId);
}
List<String> values = descObj.values();
if (values == null) {
throw new IllegalArgumentException("Plugin option values is null, option: " + optName + ", plugin: " + pluginId);
}
});
}
public boolean unload(String pluginId) {
boolean result = allPlugins.removeIf(pd -> {
String id = pd.getPluginId();
@@ -111,22 +87,6 @@ public class JadxPluginManager {
return Collections.unmodifiableList(resolvedPlugins);
}
public List<JadxInputPlugin> getInputPlugins() {
return getPluginsWithType(JadxInputPlugin.class);
}
public List<JadxPluginOptions> getPluginsWithOptions() {
return getPluginsWithType(JadxPluginOptions.class);
}
@SuppressWarnings("unchecked")
public <T extends JadxPlugin> List<T> getPluginsWithType(Class<T> type) {
return resolvedPlugins.stream()
.filter(p -> type.isAssignableFrom(p.getClass()))
.map(p -> (T) p)
.collect(Collectors.toList());
}
private synchronized void resolve() {
Map<String, List<PluginData>> provides = allPlugins.stream()
.collect(Collectors.groupingBy(p -> p.getInfo().getProvides()));
@@ -151,6 +111,52 @@ public class JadxPluginManager {
resolvedPlugins = result.stream().map(PluginData::getPlugin).collect(Collectors.toList());
}
public void initAll(PluginsContext context) {
init(context, getAllPlugins());
}
public void initResolved(PluginsContext context) {
init(context, resolvedPlugins);
}
private void init(PluginsContext context, List<JadxPlugin> plugins) {
for (JadxPlugin plugin : plugins) {
try {
context.setCurrentPlugin(plugin);
plugin.init(context);
} catch (Exception e) {
String pluginId = plugin.getPluginInfo().getPluginId();
throw new JadxRuntimeException("Failed to init plugin: " + pluginId, e);
}
}
for (Map.Entry<JadxPlugin, JadxPluginOptions> entry : context.getOptionsMap().entrySet()) {
verifyOptions(entry.getKey(), entry.getValue());
}
}
private void verifyOptions(JadxPlugin plugin, JadxPluginOptions options) {
String pluginId = plugin.getPluginInfo().getPluginId();
List<OptionDescription> descriptions = options.getOptionsDescriptions();
if (descriptions == null) {
throw new IllegalArgumentException("Null option descriptions in plugin id: " + pluginId);
}
String prefix = pluginId + '.';
descriptions.forEach(descObj -> {
String optName = descObj.name();
if (optName == null || !optName.startsWith(prefix)) {
throw new IllegalArgumentException("Plugin option name should start with plugin id: '" + prefix + "', option: " + optName);
}
String desc = descObj.description();
if (desc == null || desc.isEmpty()) {
throw new IllegalArgumentException("Plugin option description not set, plugin: " + pluginId);
}
List<String> values = descObj.values();
if (values == null) {
throw new IllegalArgumentException("Plugin option values is null, option: " + optName + ", plugin: " + pluginId);
}
});
}
private static final class PluginData implements Comparable<PluginData> {
private final JadxPlugin plugin;
private final JadxPluginInfo info;
@@ -0,0 +1,13 @@
package jadx.api.plugins.input;
import java.io.Closeable;
import java.util.function.Consumer;
import jadx.api.plugins.input.data.IClassData;
public interface ICodeLoader extends Closeable {
void visitClasses(Consumer<IClassData> consumer);
boolean isEmpty();
}
@@ -0,0 +1,8 @@
package jadx.api.plugins.input;
import java.nio.file.Path;
import java.util.List;
public interface JadxCodeInput {
ICodeLoader loadFiles(List<Path> input);
}
@@ -1,11 +0,0 @@
package jadx.api.plugins.input;
import java.nio.file.Path;
import java.util.List;
import jadx.api.plugins.JadxPlugin;
import jadx.api.plugins.input.data.ILoadResult;
public interface JadxInputPlugin extends JadxPlugin {
ILoadResult loadFiles(List<Path> input);
}
@@ -1,12 +0,0 @@
package jadx.api.plugins.input.data;
import java.io.Closeable;
import java.util.function.Consumer;
public interface ILoadResult extends Closeable {
void visitClasses(Consumer<IClassData> consumer);
void visitResources(Consumer<IResourceData> consumer);
boolean isEmpty();
}
@@ -3,13 +3,12 @@ package jadx.api.plugins.input.data.impl;
import java.io.IOException;
import java.util.function.Consumer;
import jadx.api.plugins.input.ICodeLoader;
import jadx.api.plugins.input.data.IClassData;
import jadx.api.plugins.input.data.ILoadResult;
import jadx.api.plugins.input.data.IResourceData;
public class EmptyLoadResult implements ILoadResult {
public class EmptyCodeLoader implements ICodeLoader {
public static final EmptyLoadResult INSTANCE = new EmptyLoadResult();
public static final EmptyCodeLoader INSTANCE = new EmptyCodeLoader();
@Override
public boolean isEmpty() {
@@ -20,10 +19,6 @@ public class EmptyLoadResult implements ILoadResult {
public void visitClasses(Consumer<IClassData> consumer) {
}
@Override
public void visitResources(Consumer<IResourceData> consumer) {
}
@Override
public void close() throws IOException {
}
@@ -3,9 +3,7 @@ package jadx.api.plugins.options;
import java.util.List;
import java.util.Map;
import jadx.api.plugins.JadxPlugin;
public interface JadxPluginOptions extends JadxPlugin {
public interface JadxPluginOptions {
void setOptions(Map<String, String> options);
@@ -4,9 +4,21 @@ import java.util.Locale;
import java.util.Map;
import java.util.function.Function;
public class BaseOptionsParser {
import jadx.api.plugins.options.JadxPluginOptions;
public boolean getBooleanOption(Map<String, String> options, String key, boolean defValue) {
public abstract class BaseOptionsParser implements JadxPluginOptions {
protected Map<String, String> options;
@Override
public void setOptions(Map<String, String> options) {
this.options = options;
parseOptions();
}
public abstract void parseOptions();
public boolean getBooleanOption(String key, boolean defValue) {
String val = options.get(key);
if (val == null) {
return defValue;
@@ -22,7 +34,7 @@ public class BaseOptionsParser {
+ ", expect: 'yes' or 'no'");
}
public <T> T getOption(Map<String, String> options, String key, Function<String, T> parse, T defValue) {
public <T> T getOption(String key, Function<String, T> parse, T defValue) {
String val = options.get(key);
if (val == null) {
return defValue;
@@ -1,6 +0,0 @@
package jadx.api.plugins.pass;
public interface JadxPassContext {
void addPass(JadxPass pass);
}
@@ -23,8 +23,8 @@ import jadx.api.ResourcesLoader;
import jadx.api.data.ICodeData;
import jadx.api.impl.passes.DecompilePassWrapper;
import jadx.api.impl.passes.PreparePassWrapper;
import jadx.api.plugins.input.ICodeLoader;
import jadx.api.plugins.input.data.IClassData;
import jadx.api.plugins.input.data.ILoadResult;
import jadx.api.plugins.pass.JadxPass;
import jadx.api.plugins.pass.types.JadxDecompilePass;
import jadx.api.plugins.pass.types.JadxPassType;
@@ -115,9 +115,9 @@ public class RootNode {
}
}
public void loadClasses(List<ILoadResult> loadedInputs) {
for (ILoadResult loadedInput : loadedInputs) {
loadedInput.visitClasses(cls -> {
public void loadClasses(List<ICodeLoader> loadedInputs) {
for (ICodeLoader codeLoader : loadedInputs) {
codeLoader.visitClasses(cls -> {
try {
addClassNode(new ClassNode(RootNode.this, cls));
} catch (Exception e) {
@@ -46,7 +46,7 @@ public class JadxDecompilerTest {
public void testDirectDexInput() throws IOException {
try (JadxDecompiler jadx = new JadxDecompiler();
InputStream in = new FileInputStream(getFileFromSampleDir("hello.dex"))) {
jadx.addCustomLoad(new DexInputPlugin().loadDexFromInputStream(in, "input"));
jadx.addCustomCodeLoader(new DexInputPlugin().loadDexFromInputStream(in, "input"));
jadx.load();
for (JavaClass cls : jadx.getClasses()) {
System.out.println(cls.getCode());