Fix some rgblink object file input bugs found via fuzzing with AFL++ (#1867)

- ID numbers (for fstack nodes, sections, symbols, patches, etc)
  might be too large for their associated collection
- Enum values might be invalid
- Bank values might be out of range for their section types
This commit is contained in:
Rangi
2025-12-04 20:49:16 -05:00
committed by GitHub
parent 8d6c617875
commit 131bb97ebc
4 changed files with 138 additions and 59 deletions

View File

@@ -120,7 +120,7 @@ enum SectionModifier { SECTION_NORMAL, SECTION_UNION, SECTION_FRAGMENT };
extern char const * const sectionModNames[];
enum ExportLevel { SYMTYPE_LOCAL, SYMTYPE_IMPORT, SYMTYPE_EXPORT };
enum ExportLevel { SYMTYPE_LOCAL, SYMTYPE_IMPORT, SYMTYPE_EXPORT, SYMTYPE_INVALID };
enum PatchType {
PATCHTYPE_BYTE,

View File

@@ -119,6 +119,15 @@ static std::optional<size_t> getPlacement(Section const &section, MemoryLocation
SectionTypeInfo const &typeInfo = sectionTypeInfo[section.type];
for (;;) {
if (location.bank < typeInfo.firstBank
|| location.bank >= memory[section.type].size() + typeInfo.firstBank) {
fatal(
"Invalid bank for %s section: 0x%02x",
sectionTypeInfo[section.type].name.c_str(),
location.bank
);
}
// Switch to the beginning of the next bank
std::deque<FreeSpace> &bankMem = memory[section.type][location.bank - typeInfo.firstBank];
size_t spaceIdx = 0;

View File

@@ -72,7 +72,7 @@ static int64_t readLong(FILE *file) {
tryRead(readLong, int64_t, INT64_MAX, long, var, file, __VA_ARGS__)
// Helper macro to read a byte from a file to a var, or error out if it fails to.
#define tryGetc(type, var, file, ...) tryRead(getc, int, EOF, type, var, file, __VA_ARGS__)
#define tryGetc(var, file, ...) tryRead(getc, int, EOF, uint8_t, var, file, __VA_ARGS__)
// Helper macro to read a '\0'-terminated string from a file, or error out if it fails to.
#define tryReadString(var, file, ...) \
@@ -100,27 +100,31 @@ static void readFileStackNode(
tryReadLong(
parentID, file, "%s: Cannot read node #%" PRIu32 "'s parent ID: %s", fileName, nodeID
);
node.parent = parentID != UINT32_MAX ? &fileNodes[parentID] : nullptr;
if (parentID == UINT32_MAX) {
node.parent = nullptr;
} else if (parentID >= fileNodes.size()) {
fatal("%s: Node #%" PRIu32 " has invalid parent ID #%" PRIu32, fileName, nodeID, parentID);
} else {
node.parent = &fileNodes[parentID];
}
tryReadLong(
node.lineNo, file, "%s: Cannot read node #%" PRIu32 "'s line number: %s", fileName, nodeID
);
uint8_t type;
tryGetc(uint8_t, type, file, "%s: Cannot read node #%" PRIu32 "'s type: %s", fileName, nodeID);
node.type = static_cast<FileStackNodeType>(type & ~(1 << FSTACKNODE_QUIET_BIT));
node.isQuiet = (type & (1 << FSTACKNODE_QUIET_BIT)) != 0;
switch (node.type) {
tryGetc(type, file, "%s: Cannot read node #%" PRIu32 "'s type: %s", fileName, nodeID);
switch (type & ~(1 << FSTACKNODE_QUIET_BIT)) {
case NODE_FILE:
case NODE_MACRO:
node.type = FileStackNodeType(type);
node.data = "";
tryReadString(
node.name(), file, "%s: Cannot read node #%" PRIu32 "'s file name: %s", fileName, nodeID
);
break;
case NODE_REPT: {
node.type = NODE_REPT;
uint32_t depth;
tryReadLong(
depth, file, "%s: Cannot read node #%" PRIu32 "'s REPT depth: %s", fileName, nodeID
@@ -143,8 +147,13 @@ static void readFileStackNode(
nodeID
);
}
break;
}
default:
fatal("%s: Node #%" PRIu32 " has unknown type 0x%02x", fileName, nodeID, type);
}
node.isQuiet = (type & (1 << FSTACKNODE_QUIET_BIT)) != 0;
}
// Reads a symbol from a file.
@@ -152,20 +161,25 @@ static void readSymbol(
FILE *file, Symbol &symbol, char const *fileName, std::vector<FileStackNode> const &fileNodes
) {
tryReadString(symbol.name, file, "%s: Cannot read symbol name: %s", fileName);
tryGetc(
ExportLevel,
symbol.type,
file,
"%s: Cannot read `%s`'s type: %s",
fileName,
symbol.name.c_str()
);
uint8_t type;
tryGetc(type, file, "%s: Cannot read `%s`'s type: %s", fileName, symbol.name.c_str());
if (type >= SYMTYPE_INVALID) {
fatal("%s: `%s` has unknown type 0x%02x", fileName, symbol.name.c_str(), type);
} else {
symbol.type = ExportLevel(type);
}
// If the symbol is defined in this file, read its definition
if (symbol.type != SYMTYPE_IMPORT) {
uint32_t nodeID;
tryReadLong(
nodeID, file, "%s: Cannot read `%s`'s node ID: %s", fileName, symbol.name.c_str()
);
if (nodeID >= fileNodes.size()) {
fatal("%s: `%s` has invalid node ID #%" PRIu32, fileName, symbol.name.c_str(), nodeID);
}
symbol.src = &fileNodes[nodeID];
tryReadLong(
symbol.lineNo,
@@ -212,6 +226,15 @@ static void readPatch(
sectName.c_str(),
patchID
);
if (nodeID >= fileNodes.size()) {
fatal(
"%s: \"%s\"'s patch #%" PRIu32 " has invalid node ID #%" PRIu32,
fileName,
sectName.c_str(),
patchID,
nodeID
);
}
patch.src = &fileNodes[nodeID];
tryReadLong(
@@ -247,9 +270,8 @@ static void readPatch(
patchID
);
PatchType type;
uint8_t type;
tryGetc(
PatchType,
type,
file,
"%s: Cannot read \"%s\"'s patch #%" PRIu32 "'s type: %s",
@@ -257,7 +279,17 @@ static void readPatch(
sectName.c_str(),
patchID
);
patch.type = type;
if (type >= PATCHTYPE_INVALID) {
fatal(
"%s: \"%s\"'s patch #%" PRIu32 " has unknown type 0x%02x",
fileName,
sectName.c_str(),
patchID,
type
);
} else {
patch.type = PatchType(type);
}
uint32_t rpnSize;
tryReadLong(
@@ -281,13 +313,6 @@ static void readPatch(
}
}
// Sets a patch's `pcSection` from its `pcSectionID`.
static void
linkPatchToPCSect(Patch &patch, std::vector<std::unique_ptr<Section>> const &fileSections) {
patch.pcSection =
patch.pcSectionID != UINT32_MAX ? fileSections[patch.pcSectionID].get() : nullptr;
}
// Reads a section from a file.
static void readSection(
FILE *file, Section &section, char const *fileName, std::vector<FileStackNode> const &fileNodes
@@ -296,11 +321,16 @@ static void readSection(
uint8_t byte;
tryReadString(section.name, file, "%s: Cannot read section name: %s", fileName);
uint32_t nodeID;
tryReadLong(
nodeID, file, "%s: Cannot read \"%s\"'s node ID: %s", fileName, section.name.c_str()
);
if (nodeID >= fileNodes.size()) {
fatal("%s: \"%s\" has invalid node ID #%" PRIu32, fileName, section.name.c_str(), nodeID);
}
section.src = &fileNodes[nodeID];
tryReadLong(
section.lineNo,
file,
@@ -310,18 +340,23 @@ static void readSection(
);
tryReadLong(tmp, file, "%s: Cannot read \"%s\"'s' size: %s", fileName, section.name.c_str());
if (tmp < 0 || tmp > UINT16_MAX) {
fatal("\"%s\"'s section size ($%" PRIx32 ") is invalid", section.name.c_str(), tmp);
fatal(
"%s: \"%s\"'s section size ($%" PRIx32 ") is invalid",
fileName,
section.name.c_str(),
tmp
);
}
section.size = tmp;
section.offset = 0;
tryGetc(
uint8_t, byte, file, "%s: Cannot read \"%s\"'s type: %s", fileName, section.name.c_str()
);
tryGetc(byte, file, "%s: Cannot read \"%s\"'s type: %s", fileName, section.name.c_str());
if (uint8_t type = byte & SECTTYPE_TYPE_MASK; type >= SECTTYPE_INVALID) {
fatal("\"%s\" has unknown section type 0x%02x", section.name.c_str(), type);
fatal("%s: \"%s\" has unknown section type 0x%02x", fileName, section.name.c_str(), type);
} else {
section.type = SectionType(type);
}
if (byte & (1 << SECTTYPE_UNION_BIT)) {
section.modifier = SECTION_UNION;
} else if (byte & (1 << SECTTYPE_FRAGMENT_BIT)) {
@@ -339,14 +374,7 @@ static void readSection(
tryReadLong(tmp, file, "%s: Cannot read \"%s\"'s bank: %s", fileName, section.name.c_str());
section.isBankFixed = tmp >= 0;
section.bank = tmp;
tryGetc(
uint8_t,
byte,
file,
"%s: Cannot read \"%s\"'s alignment: %s",
fileName,
section.name.c_str()
);
tryGetc(byte, file, "%s: Cannot read \"%s\"'s alignment: %s", fileName, section.name.c_str());
if (byte > 16) {
byte = 16;
}
@@ -499,7 +527,16 @@ void obj_ReadFile(std::string const &filePath, size_t fileID) {
readSymbol(file, sym, fileName, nodes[fileID]);
sym_AddSymbol(sym);
if (std::holds_alternative<Label>(sym.data)) {
++nbSymPerSect[std::get<Label>(sym.data).sectionID];
int32_t sectionID = std::get<Label>(sym.data).sectionID;
if (sectionID < 0 || static_cast<size_t>(sectionID) >= nbSymPerSect.size()) {
fatal(
"%s: `%s` has invalid section ID #%" PRId32,
fileName,
sym.name.c_str(),
sectionID
);
}
++nbSymPerSect[sectionID];
}
}
@@ -522,15 +559,41 @@ void obj_ReadFile(std::string const &filePath, size_t fileID) {
Assertion &assertion = patch_AddAssertion();
readAssertion(file, assertion, fileName, i, nodes[fileID]);
linkPatchToPCSect(assertion.patch, fileSections);
if (assertion.patch.pcSectionID == UINT32_MAX) {
assertion.patch.pcSection = nullptr;
} else if (assertion.patch.pcSectionID >= fileSections.size()) {
fatal(
"%s: Assertion #%" PRIu32 "'s patch has invalid section ID #%" PRIu32,
fileName,
i,
assertion.patch.pcSectionID
);
} else {
assertion.patch.pcSection = fileSections[assertion.patch.pcSectionID].get();
}
assertion.fileSymbols = &fileSymbols;
}
// Give patches' PC section pointers to their sections
for (std::unique_ptr<Section> const &sect : fileSections) {
if (sectTypeHasData(sect->type)) {
for (Patch &patch : sect->patches) {
linkPatchToPCSect(patch, fileSections);
if (!sectTypeHasData(sect->type)) {
continue;
}
for (size_t i = 0; i < sect->patches.size(); ++i) {
if (Patch &patch = sect->patches[i]; patch.pcSectionID == UINT32_MAX) {
patch.pcSection = nullptr;
} else if (patch.pcSectionID >= fileSections.size()) {
fatal(
"%s: \"%s\"'s patch #%zu has invalid section ID #%" PRIu32,
fileName,
sect->name.c_str(),
i,
patch.pcSectionID
);
} else {
patch.pcSection = fileSections[patch.pcSectionID].get();
}
}
}

View File

@@ -78,7 +78,8 @@ static uint32_t getRPNByte(uint8_t const *&expression, int32_t &size, Patch cons
}
static Symbol const *getSymbol(std::vector<Symbol> const &symbolList, uint32_t index) {
assume(index != UINT32_MAX); // PC needs to be handled specially, not here
assume(index != UINT32_MAX); // PC needs to be handled specially, not here
assume(index < symbolList.size()); // This needs to be checked before calling
Symbol const &symbol = symbolList[index];
// If the symbol is defined elsewhere...
@@ -270,17 +271,19 @@ static int32_t computeRPNExpr(Patch const &patch, std::vector<Symbol> const &fil
value = op_shift_right_unsigned(popRPN(patch), value);
break;
case RPN_BANK_SYM:
value = 0;
case RPN_BANK_SYM: {
uint32_t symID = 0;
for (uint8_t shift = 0; shift < 32; shift += 8) {
value |= getRPNByte(expression, size, patch) << shift;
symID |= getRPNByte(expression, size, patch) << shift;
}
if (Symbol const *symbol = getSymbol(fileSymbols, value); !symbol) {
if (symID >= fileSymbols.size()) {
fatalAt(patch, "Requested `BANK()` of invalid symbol ID #%" PRIu32, symID);
} else if (Symbol const *symbol = getSymbol(fileSymbols, symID); !symbol) {
errorAt(
patch,
"Requested `BANK()` of undefined symbol `%s`",
fileSymbols[value].name.c_str()
fileSymbols[symID].name.c_str()
);
isError = true;
value = 1;
@@ -290,12 +293,13 @@ static int32_t computeRPNExpr(Patch const &patch, std::vector<Symbol> const &fil
errorAt(
patch,
"Requested `BANK()` of non-label symbol `%s`",
fileSymbols[value].name.c_str()
fileSymbols[symID].name.c_str()
);
isError = true;
value = 1;
}
break;
}
case RPN_BANK_SECT: {
// `expression` is not guaranteed to be '\0'-terminated. If it is not,
@@ -420,13 +424,13 @@ static int32_t computeRPNExpr(Patch const &patch, std::vector<Symbol> const &fil
}
break;
case RPN_SYM:
value = 0;
case RPN_SYM: {
uint32_t symID = 0;
for (uint8_t shift = 0; shift < 32; shift += 8) {
value |= getRPNByte(expression, size, patch) << shift;
symID |= getRPNByte(expression, size, patch) << shift;
}
if (value == -1) { // PC
if (symID == UINT32_MAX) { // PC
if (patch.pcSection) {
value = patch.pcOffset + patch.pcSection->org;
} else {
@@ -434,9 +438,11 @@ static int32_t computeRPNExpr(Patch const &patch, std::vector<Symbol> const &fil
value = 0;
isError = true;
}
} else if (Symbol const *symbol = getSymbol(fileSymbols, value); !symbol) {
errorAt(patch, "Undefined symbol `%s`", fileSymbols[value].name.c_str());
sym_TraceLocalAliasedSymbols(fileSymbols[value].name);
} else if (symID >= fileSymbols.size()) {
fatalAt(patch, "Invalid symbol ID #%" PRIu32, symID);
} else if (Symbol const *symbol = getSymbol(fileSymbols, symID); !symbol) {
errorAt(patch, "Undefined symbol `%s`", fileSymbols[symID].name.c_str());
sym_TraceLocalAliasedSymbols(fileSymbols[symID].name);
value = 0;
isError = true;
} else if (std::holds_alternative<Label>(symbol->data)) {
@@ -447,6 +453,7 @@ static int32_t computeRPNExpr(Patch const &patch, std::vector<Symbol> const &fil
}
break;
}
}
pushRPN(value, isError);
}