fix: improve switch over string restore (#2359)

This commit is contained in:
Skylot
2026-01-21 20:02:05 +00:00
parent 54265e34e5
commit ad267e1618
10 changed files with 264 additions and 83 deletions
@@ -483,7 +483,7 @@ public class InsnDecoder {
case FILL_ARRAY_DATA:
return new FillArrayInsn(InsnArg.reg(insn, 0, ArgType.UNKNOWN_ARRAY), insn.getTarget());
case FILL_ARRAY_DATA_PAYLOAD:
return new FillArrayData(((IArrayPayload) Objects.requireNonNull(insn.getPayload())));
return new FillArrayData((IArrayPayload) Objects.requireNonNull(insn.getPayload()));
case FILLED_NEW_ARRAY:
return filledNewArray(insn, false);
@@ -497,7 +497,7 @@ public class InsnDecoder {
case PACKED_SWITCH_PAYLOAD:
case SPARSE_SWITCH_PAYLOAD:
return new SwitchData(((ISwitchPayload) insn.getPayload()));
return new SwitchData((ISwitchPayload) insn.getPayload());
case MONITOR_ENTER:
return insn(InsnType.MONITOR_ENTER,
@@ -515,7 +515,7 @@ public class InsnDecoder {
}
private SwitchInsn makeSwitch(InsnData insn, boolean packed) {
SwitchInsn swInsn = new SwitchInsn(InsnArg.reg(insn, 0, ArgType.UNKNOWN), insn.getTarget(), packed);
SwitchInsn swInsn = new SwitchInsn(InsnArg.reg(insn, 0, ArgType.NARROW_INTEGRAL), insn.getTarget(), packed);
ICustomPayload payload = insn.getPayload();
if (payload != null) {
swInsn.attachSwitchData(new SwitchData((ISwitchPayload) payload), insn.getTarget());
@@ -69,6 +69,15 @@ public final class PhiInsn extends InsnNode {
return (RegisterArg) super.getArg(n);
}
public @Nullable RegisterArg getArgByBlock(BlockNode block) {
for (int i = 0; i < blockBinds.size(); i++) {
if (blockBinds.get(i) == block) {
return getArg(i);
}
}
return null;
}
@Override
public boolean removeArg(InsnArg arg) {
int index = getArgIndex(arg);
@@ -61,6 +61,9 @@ public abstract class ArgType {
PrimitiveType.INT, PrimitiveType.FLOAT,
PrimitiveType.SHORT, PrimitiveType.BYTE, PrimitiveType.CHAR);
public static final ArgType NARROW_NEG_NUMBERS = unknown(
PrimitiveType.INT, PrimitiveType.SHORT, PrimitiveType.BYTE, PrimitiveType.FLOAT);
public static final ArgType NARROW_NUMBERS_NO_FLOAT = unknown(
PrimitiveType.INT, PrimitiveType.BOOLEAN,
PrimitiveType.SHORT, PrimitiveType.BYTE, PrimitiveType.CHAR);
@@ -23,6 +23,9 @@ public final class LiteralArg extends InsnArg {
if (value == 1) {
return ArgType.NARROW_NUMBERS;
}
if (value < 0) {
return ArgType.NARROW_NEG_NUMBERS;
}
return ArgType.NARROW_NUMBERS_NO_BOOL;
}
@@ -41,6 +41,10 @@ public final class SwitchRegion extends AbstractRegion implements IBranchRegion
this.container = container;
}
public boolean isDefaultCase() {
return keys.size() == 1 && keys.get(0) == DEFAULT_CASE_KEY;
}
public List<Object> getKeys() {
return keys;
}
@@ -18,6 +18,7 @@ import jadx.api.plugins.input.data.attributes.IJadxAttrType;
import jadx.api.plugins.input.data.attributes.IJadxAttribute;
import jadx.core.dex.attributes.AFlag;
import jadx.core.dex.attributes.AType;
import jadx.core.dex.attributes.nodes.CodeFeaturesAttr;
import jadx.core.dex.attributes.nodes.LoopInfo;
import jadx.core.dex.instructions.InsnType;
import jadx.core.dex.instructions.args.InsnArg;
@@ -315,6 +316,13 @@ public class BlockProcessor extends AbstractVisitor {
if (mergeConstReturn(mth)) {
return true;
}
if (CodeFeaturesAttr.contains(mth, CodeFeaturesAttr.CodeFeature.SWITCH)) {
for (BlockNode basicBlock : mth.getBasicBlocks()) {
if (duplicateSimpleMoveBlock(mth, basicBlock)) {
return true;
}
}
}
return splitExitBlocks(mth);
}
@@ -383,6 +391,65 @@ public class BlockProcessor extends AbstractVisitor {
return changed;
}
/**
* Duplicate block if it contains only one 'move' insn and all predecessors are 'switch' and 'if'.
* This will help to resolve switch cases order and fallthrough detection
* because such move blocks can be deduplicated by compiler.
*/
private static boolean duplicateSimpleMoveBlock(MethodNode mth, BlockNode block) {
List<InsnNode> insns = block.getInstructions();
if (insns.size() == 1 && block.getSuccessors().size() == 1) {
InsnNode insn = insns.get(0);
if (insn.getType() == InsnType.MOVE) {
List<BlockNode> preds = block.getPredecessors();
int predSize = preds.size();
if (predSize >= 3 && onlySwitchAndIfInLastInsns(preds)) {
// confirmed, duplicate block
BlockNode successor = block.getSuccessors().get(0);
List<BlockNode> predsCopy = new ArrayList<>(preds);
for (int i = 1; i < predSize; i++) {
BlockNode pred = predsCopy.get(i);
BlockNode newBlock = BlockSplitter.startNewBlock(mth, -1);
newBlock.add(AFlag.SYNTHETIC);
for (InsnNode oldInsn : block.getInstructions()) {
InsnNode copyInsn = oldInsn.copyWithoutSsa();
copyInsn.add(AFlag.SYNTHETIC);
newBlock.getInstructions().add(copyInsn);
}
newBlock.copyAttributesFrom(block);
BlockSplitter.replaceConnection(pred, block, newBlock);
BlockSplitter.connect(newBlock, successor);
}
return true;
}
}
}
return false;
}
private static boolean onlySwitchAndIfInLastInsns(List<BlockNode> preds) {
boolean hasSwitch = false;
boolean hasIf = false;
for (BlockNode pred : preds) {
InsnNode lastInsn = BlockUtils.getLastInsn(pred);
if (lastInsn == null) {
return false;
}
InsnType insnType = lastInsn.getType();
switch (insnType) {
case SWITCH:
hasSwitch = true;
break;
case IF:
hasIf = true;
break;
default:
return false;
}
}
return hasSwitch && hasIf;
}
private static boolean simplifyLoopEnd(MethodNode mth, LoopInfo loop) {
BlockNode loopEnd = loop.getEnd();
if (loopEnd.getSuccessors().size() <= 1) {
@@ -13,6 +13,7 @@ import java.util.concurrent.atomic.AtomicBoolean;
import org.jetbrains.annotations.Nullable;
import jadx.api.plugins.input.data.annotations.EncodedType;
import jadx.api.plugins.input.data.annotations.EncodedValue;
import jadx.api.plugins.input.data.attributes.JadxAttrType;
import jadx.core.dex.attributes.AFlag;
import jadx.core.dex.attributes.IAttributeNode;
@@ -23,17 +24,20 @@ import jadx.core.dex.instructions.IfNode;
import jadx.core.dex.instructions.IfOp;
import jadx.core.dex.instructions.InsnType;
import jadx.core.dex.instructions.InvokeNode;
import jadx.core.dex.instructions.PhiInsn;
import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.instructions.args.InsnWrapArg;
import jadx.core.dex.instructions.args.LiteralArg;
import jadx.core.dex.instructions.args.RegisterArg;
import jadx.core.dex.instructions.args.SSAVar;
import jadx.core.dex.nodes.BlockNode;
import jadx.core.dex.nodes.FieldNode;
import jadx.core.dex.nodes.IContainer;
import jadx.core.dex.nodes.IRegion;
import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.regions.SwitchRegion;
import jadx.core.dex.regions.conditions.Compare;
import jadx.core.dex.regions.conditions.IfCondition;
import jadx.core.dex.regions.conditions.IfRegion;
import jadx.core.dex.visitors.AbstractVisitor;
@@ -42,6 +46,7 @@ import jadx.core.utils.BlockUtils;
import jadx.core.utils.InsnRemover;
import jadx.core.utils.InsnUtils;
import jadx.core.utils.RegionUtils;
import jadx.core.utils.Utils;
import jadx.core.utils.exceptions.JadxException;
@JadxVisitor(
@@ -79,19 +84,21 @@ public class SwitchOverStringVisitor extends AbstractVisitor implements IRegionI
return false;
}
int casesCount = switchRegion.getCases().size();
boolean defaultCaseAdded = switchRegion.getCases().stream().anyMatch(SwitchRegion.CaseInfo::isDefaultCase);
int casesWithString = defaultCaseAdded ? casesCount - 1 : casesCount;
SSAVar strVar = strArg.getSVar();
if (strVar.getUseCount() - 1 < casesCount) {
if (strVar.getUseCount() - 1 < casesWithString) {
// one 'hashCode' invoke and at least one 'equals' per case
return false;
}
// quick checks done, start collecting data to create a new switch region
Map<InsnNode, String> strEqInsns = collectEqualsInsns(mth, strVar);
if (strEqInsns.size() < casesCount) {
if (strEqInsns.size() < casesWithString) {
return false;
}
SwitchData switchData = new SwitchData(mth, switchRegion);
switchData.setStrEqInsns(strEqInsns);
switchData.setCases(new ArrayList<>(strEqInsns.size()));
switchData.setCases(new ArrayList<>(casesCount));
for (SwitchRegion.CaseInfo swCaseInfo : switchRegion.getCases()) {
if (!processCase(switchData, swCaseInfo)) {
mth.addWarnComment("Failed to restore switch over string. Please report as a decompilation issue");
@@ -157,32 +164,27 @@ public class SwitchOverStringVisitor extends AbstractVisitor implements IRegionI
}
private boolean mergeWithCode(SwitchData switchData) {
// check for second switch
IContainer nextContainer = RegionUtils.getNextContainer(switchData.getMth(), switchData.getSwitchRegion());
if (!(nextContainer instanceof SwitchRegion)) {
return false;
}
SwitchRegion codeSwitch = (SwitchRegion) nextContainer;
InsnNode swInsn = BlockUtils.getLastInsnWithType(codeSwitch.getHeader(), InsnType.SWITCH);
if (swInsn == null || !swInsn.getArg(0).isRegister()) {
return false;
}
RegisterArg numArg = (RegisterArg) swInsn.getArg(0);
List<CaseData> cases = switchData.getCases();
// search index assign in cases code
RegisterArg numArg = null;
int extracted = 0;
for (CaseData caseData : cases) {
IContainer container = caseData.getCode();
List<InsnNode> insns = RegionUtils.collectInsns(switchData.getMth(), container);
insns.removeIf(i -> i.getType() == InsnType.BREAK);
if (insns.size() != 1) {
continue;
}
InsnNode numInsn = insns.get(0);
if (numInsn.getArgsCount() == 1) {
Object constVal = InsnUtils.getConstValueByArg(switchData.getMth().root(), numInsn.getArg(0));
if (constVal instanceof LiteralArg) {
if (numArg == null) {
numArg = numInsn.getResult();
} else {
if (!numArg.sameCodeVar(numInsn.getResult())) {
return false;
}
}
int num = (int) ((LiteralArg) constVal).getLiteral();
caseData.setCodeNum(num);
extracted++;
}
InsnNode numInsn = searchConstInsn(switchData, caseData, swInsn);
Integer num = extractConstNumber(switchData, numInsn, numArg);
if (num != null) {
caseData.setCodeNum(num);
extracted++;
}
}
if (extracted == 0) {
@@ -195,16 +197,7 @@ public class SwitchOverStringVisitor extends AbstractVisitor implements IRegionI
// TODO: additional checks for found index numbers
cases.sort(Comparator.comparingInt(CaseData::getCodeNum));
// extract complete, second switch on 'numArg' should be the next region
IContainer nextContainer = RegionUtils.getNextContainer(switchData.getMth(), switchData.getSwitchRegion());
if (!(nextContainer instanceof SwitchRegion)) {
return false;
}
SwitchRegion codeSwitch = (SwitchRegion) nextContainer;
InsnNode swInsn = BlockUtils.getLastInsnWithType(codeSwitch.getHeader(), InsnType.SWITCH);
if (swInsn == null || !swInsn.getArg(0).isSameCodeVar(numArg)) {
return false;
}
// extract complete
Map<Integer, CaseData> casesMap = new HashMap<>(cases.size());
for (CaseData caseData : cases) {
CaseData prev = casesMap.put(caseData.getCodeNum(), caseData);
@@ -215,42 +208,39 @@ public class SwitchOverStringVisitor extends AbstractVisitor implements IRegionI
block -> switchData.getToRemove().add(block));
}
final var newCases = new ArrayList<SwitchRegion.CaseInfo>();
List<SwitchRegion.CaseInfo> newCases = new ArrayList<>();
for (SwitchRegion.CaseInfo caseInfo : codeSwitch.getCases()) {
SwitchRegion.CaseInfo newCase = null;
for (Object key : caseInfo.getKeys()) {
final Integer intKey = unwrapIntKey(key);
Integer intKey = unwrapIntKey(key);
if (intKey != null) {
final var caseData = casesMap.remove(intKey);
CaseData caseData = casesMap.remove(intKey);
if (caseData == null) {
return false;
}
if (newCase == null) {
final List<Object> keys = new ArrayList<>(caseData.getStrValues());
List<Object> keys = new ArrayList<>(caseData.getStrValues());
newCase = new SwitchRegion.CaseInfo(keys, caseInfo.getContainer());
} else {
// merge cases
newCase.getKeys().addAll(caseData.getStrValues());
}
} else if (key == SwitchRegion.DEFAULT_CASE_KEY) {
final var iterator = casesMap.entrySet().iterator();
var iterator = casesMap.entrySet().iterator();
while (iterator.hasNext()) {
final var caseData = iterator.next().getValue();
CaseData caseData = iterator.next().getValue();
if (newCase == null) {
final List<Object> keys = new ArrayList<>(caseData.getStrValues());
List<Object> keys = new ArrayList<>(caseData.getStrValues());
newCase = new SwitchRegion.CaseInfo(keys, caseInfo.getContainer());
} else {
// merge cases
newCase.getKeys().addAll(caseData.getStrValues());
}
iterator.remove();
}
if (newCase == null) {
newCase = new SwitchRegion.CaseInfo(new ArrayList<>(), caseInfo.getContainer());
}
newCase.getKeys().add(SwitchRegion.DEFAULT_CASE_KEY);
} else {
return false;
@@ -258,25 +248,61 @@ public class SwitchOverStringVisitor extends AbstractVisitor implements IRegionI
}
newCases.add(newCase);
}
switchData.setCodeSwitch(codeSwitch);
switchData.setNumArg(numArg);
switchData.setNewCases(newCases);
return true;
}
private @Nullable Integer extractConstNumber(SwitchData switchData, @Nullable InsnNode numInsn, RegisterArg numArg) {
if (numInsn == null || numInsn.getArgsCount() != 1) {
return null;
}
Object constVal = InsnUtils.getConstValueByArg(switchData.getMth().root(), numInsn.getArg(0));
if (constVal instanceof LiteralArg) {
if (numArg.sameCodeVar(numInsn.getResult())) {
return (int) ((LiteralArg) constVal).getLiteral();
}
}
return null;
}
private static @Nullable InsnNode searchConstInsn(SwitchData switchData, CaseData caseData, InsnNode swInsn) {
IContainer container = caseData.getCode();
if (container != null) {
List<InsnNode> insns = RegionUtils.collectInsns(switchData.getMth(), container);
insns.removeIf(i -> i.getType() == InsnType.BREAK);
if (insns.size() == 1) {
return insns.get(0);
}
} else if (caseData.getBlockRef() != null) {
// variable used unchanged on path from block ref
BlockNode blockRef = caseData.getBlockRef();
InsnArg swArg = swInsn.getArg(0);
if (swArg.isRegister()) {
InsnNode assignInsn = ((RegisterArg) swArg).getSVar().getAssignInsn();
if (assignInsn != null && assignInsn.getType() == InsnType.PHI) {
RegisterArg arg = ((PhiInsn) assignInsn).getArgByBlock(blockRef);
if (arg != null) {
return arg.getAssignInsn();
}
}
}
}
return null;
}
private Integer unwrapIntKey(Object key) {
if (key instanceof Integer) {
return (Integer) key;
} else if (key instanceof FieldNode) {
final var encodedValue = ((FieldNode) key).get(JadxAttrType.CONSTANT_VALUE);
}
if (key instanceof FieldNode) {
EncodedValue encodedValue = ((FieldNode) key).get(JadxAttrType.CONSTANT_VALUE);
if (encodedValue != null && encodedValue.getType() == EncodedType.ENCODED_INT) {
return (Integer) encodedValue.getValue();
} else {
return null;
}
return null;
}
return null;
}
@@ -299,6 +325,11 @@ public class SwitchOverStringVisitor extends AbstractVisitor implements IRegionI
}
private boolean processCase(SwitchData switchData, SwitchRegion.CaseInfo caseInfo) {
if (caseInfo.isDefaultCase()) {
CaseData caseData = new CaseData();
caseData.setCode(caseInfo.getContainer());
return true;
}
AtomicBoolean fail = new AtomicBoolean(false);
RegionUtils.visitRegions(switchData.getMth(), caseInfo.getContainer(), region -> {
if (fail.get()) {
@@ -324,30 +355,39 @@ public class SwitchOverStringVisitor extends AbstractVisitor implements IRegionI
condition = condition.getArgs().get(0);
neg = true;
}
Compare compare = condition.getCompare();
if (compare == null) {
return null;
}
IfNode ifInsn = compare.getInsn();
InsnArg firstArg = ifInsn.getArg(0);
String str = null;
if (condition.isCompare()) {
IfNode ifInsn = condition.getCompare().getInsn();
InsnArg firstArg = ifInsn.getArg(0);
if (firstArg.isInsnWrap()) {
str = switchData.getStrEqInsns().get(((InsnWrapArg) firstArg).getWrapInsn());
}
if (ifInsn.getOp() == IfOp.NE && ifInsn.getArg(1).isTrue()) {
neg = true;
}
if (ifInsn.getOp() == IfOp.EQ && ifInsn.getArg(1).isFalse()) {
neg = true;
}
if (str != null) {
switchData.getToRemove().add(ifInsn);
switchData.getToRemove().addAll(ifRegion.getConditionBlocks());
}
if (firstArg.isInsnWrap()) {
str = switchData.getStrEqInsns().get(((InsnWrapArg) firstArg).getWrapInsn());
}
if (str == null) {
return null;
}
if (ifInsn.getOp() == IfOp.NE && ifInsn.getArg(1).isTrue()) {
neg = true;
}
if (ifInsn.getOp() == IfOp.EQ && ifInsn.getArg(1).isFalse()) {
neg = true;
}
switchData.getToRemove().add(ifInsn);
switchData.getToRemove().addAll(ifRegion.getConditionBlocks());
CaseData caseData = new CaseData();
caseData.getStrValues().add(str);
caseData.setCode(neg ? ifRegion.getElseRegion() : ifRegion.getThenRegion());
IContainer codeContainer = neg ? ifRegion.getElseRegion() : ifRegion.getThenRegion();
if (codeContainer == null) {
// no code
// use last condition block for later data tracing
caseData.setBlockRef(Utils.last(ifRegion.getConditionBlocks()));
} else {
caseData.setCode(codeContainer);
}
return caseData;
}
@@ -447,21 +487,30 @@ public class SwitchOverStringVisitor extends AbstractVisitor implements IRegionI
private static final class CaseData {
private final List<String> strValues = new ArrayList<>();
private IContainer code = null;
private @Nullable IContainer code = null;
private @Nullable BlockNode blockRef = null;
private int codeNum = -1;
public List<String> getStrValues() {
return strValues;
}
public IContainer getCode() {
public @Nullable IContainer getCode() {
return code;
}
public void setCode(IContainer code) {
public void setCode(@Nullable IContainer code) {
this.code = code;
}
public @Nullable BlockNode getBlockRef() {
return blockRef;
}
public void setBlockRef(@Nullable BlockNode blockRef) {
this.blockRef = blockRef;
}
public int getCodeNum() {
return codeNum;
}
@@ -447,14 +447,9 @@ public final class TypeUpdate {
boolean assignChanged = isAssign(insn, arg);
InsnArg changeArg = assignChanged ? insn.getArg(0) : insn.getResult();
boolean correctType;
if (changeArg.getType().isTypeKnown()) {
// allow result to be wider
TypeCompareEnum cmp = comparator.compareTypes(candidateType, changeArg.getType());
correctType = cmp.isEqual() || (assignChanged ? cmp.isWider() : cmp.isNarrow());
} else {
correctType = true;
}
// allow result to be wider
TypeCompareEnum cmp = comparator.compareTypes(candidateType, changeArg.getType());
boolean correctType = cmp.isEqual() || (assignChanged ? cmp.isWider() : cmp.isNarrow());
TypeUpdateResult result = updateTypeChecked(updateInfo, changeArg, candidateType);
if (result == SAME && !correctType) {
@@ -49,6 +49,8 @@ import jadx.core.utils.exceptions.JadxException;
public class DebugUtils {
private static final Logger LOG = LoggerFactory.getLogger(DebugUtils.class);
public static final Predicate<MethodNode> TEST_MTH_FILTER = mth -> mth.getName().equals("test");
private DebugUtils() {
}
@@ -63,7 +65,7 @@ public class DebugUtils {
}
public static void dumpRawTest(MethodNode mth, String desc) {
dumpRaw(mth, desc, method -> method.getName().equals("test"));
dumpRaw(mth, desc, TEST_MTH_FILTER);
}
public static void dumpRaw(MethodNode mth, String desc) {
@@ -91,6 +93,10 @@ public class DebugUtils {
};
}
public static IDexTreeVisitor dumpRawTestVisitor(String desc) {
return dumpRawVisitor(desc, TEST_MTH_FILTER);
}
public static void dump(MethodNode mth, String desc) {
File out = new File("test-graph-" + desc + "-tmp");
DotGraphVisitor.dump().save(out, mth);
@@ -0,0 +1,45 @@
package jadx.tests.integration.switches;
import org.junit.jupiter.api.Test;
import jadx.tests.api.IntegrationTest;
import static jadx.tests.api.utils.assertj.JadxAssertions.assertThat;
public class TestSwitchOverStrings3 extends IntegrationTest {
@SuppressWarnings("SwitchStatementWithTooFewBranches")
public static class TestCls {
public int test(String v) {
switch (v) {
case "a":
return 1;
default:
switch (v) {
case "b":
return 2;
case "c":
return 3;
default:
return 4;
}
}
}
public void check() {
assertThat(test("a")).isEqualTo(1);
assertThat(test("b")).isEqualTo(2);
assertThat(test("c")).isEqualTo(3);
assertThat(test("d")).isEqualTo(4);
}
}
@Test
public void test() {
assertThat(getClassNode(TestCls.class))
.code()
.countString(3, "case ")
.countString(2, "default:")
.countString(4, "return ");
}
}