Fix some rgblink object file input bugs found via fuzzing with AFL++ (#1867)

- ID numbers (for fstack nodes, sections, symbols, patches, etc)
  might be too large for their associated collection
- Enum values might be invalid
- Bank values might be out of range for their section types
This commit is contained in:
Rangi
2025-12-04 20:49:16 -05:00
committed by GitHub
parent 8d6c617875
commit 131bb97ebc
4 changed files with 138 additions and 59 deletions

View File

@@ -120,7 +120,7 @@ enum SectionModifier { SECTION_NORMAL, SECTION_UNION, SECTION_FRAGMENT };
extern char const * const sectionModNames[]; extern char const * const sectionModNames[];
enum ExportLevel { SYMTYPE_LOCAL, SYMTYPE_IMPORT, SYMTYPE_EXPORT }; enum ExportLevel { SYMTYPE_LOCAL, SYMTYPE_IMPORT, SYMTYPE_EXPORT, SYMTYPE_INVALID };
enum PatchType { enum PatchType {
PATCHTYPE_BYTE, PATCHTYPE_BYTE,

View File

@@ -119,6 +119,15 @@ static std::optional<size_t> getPlacement(Section const &section, MemoryLocation
SectionTypeInfo const &typeInfo = sectionTypeInfo[section.type]; SectionTypeInfo const &typeInfo = sectionTypeInfo[section.type];
for (;;) { for (;;) {
if (location.bank < typeInfo.firstBank
|| location.bank >= memory[section.type].size() + typeInfo.firstBank) {
fatal(
"Invalid bank for %s section: 0x%02x",
sectionTypeInfo[section.type].name.c_str(),
location.bank
);
}
// Switch to the beginning of the next bank // Switch to the beginning of the next bank
std::deque<FreeSpace> &bankMem = memory[section.type][location.bank - typeInfo.firstBank]; std::deque<FreeSpace> &bankMem = memory[section.type][location.bank - typeInfo.firstBank];
size_t spaceIdx = 0; size_t spaceIdx = 0;

View File

@@ -72,7 +72,7 @@ static int64_t readLong(FILE *file) {
tryRead(readLong, int64_t, INT64_MAX, long, var, file, __VA_ARGS__) tryRead(readLong, int64_t, INT64_MAX, long, var, file, __VA_ARGS__)
// Helper macro to read a byte from a file to a var, or error out if it fails to. // Helper macro to read a byte from a file to a var, or error out if it fails to.
#define tryGetc(type, var, file, ...) tryRead(getc, int, EOF, type, var, file, __VA_ARGS__) #define tryGetc(var, file, ...) tryRead(getc, int, EOF, uint8_t, var, file, __VA_ARGS__)
// Helper macro to read a '\0'-terminated string from a file, or error out if it fails to. // Helper macro to read a '\0'-terminated string from a file, or error out if it fails to.
#define tryReadString(var, file, ...) \ #define tryReadString(var, file, ...) \
@@ -100,27 +100,31 @@ static void readFileStackNode(
tryReadLong( tryReadLong(
parentID, file, "%s: Cannot read node #%" PRIu32 "'s parent ID: %s", fileName, nodeID parentID, file, "%s: Cannot read node #%" PRIu32 "'s parent ID: %s", fileName, nodeID
); );
node.parent = parentID != UINT32_MAX ? &fileNodes[parentID] : nullptr; if (parentID == UINT32_MAX) {
node.parent = nullptr;
} else if (parentID >= fileNodes.size()) {
fatal("%s: Node #%" PRIu32 " has invalid parent ID #%" PRIu32, fileName, nodeID, parentID);
} else {
node.parent = &fileNodes[parentID];
}
tryReadLong( tryReadLong(
node.lineNo, file, "%s: Cannot read node #%" PRIu32 "'s line number: %s", fileName, nodeID node.lineNo, file, "%s: Cannot read node #%" PRIu32 "'s line number: %s", fileName, nodeID
); );
uint8_t type; uint8_t type;
tryGetc(uint8_t, type, file, "%s: Cannot read node #%" PRIu32 "'s type: %s", fileName, nodeID); tryGetc(type, file, "%s: Cannot read node #%" PRIu32 "'s type: %s", fileName, nodeID);
node.type = static_cast<FileStackNodeType>(type & ~(1 << FSTACKNODE_QUIET_BIT)); switch (type & ~(1 << FSTACKNODE_QUIET_BIT)) {
node.isQuiet = (type & (1 << FSTACKNODE_QUIET_BIT)) != 0;
switch (node.type) {
case NODE_FILE: case NODE_FILE:
case NODE_MACRO: case NODE_MACRO:
node.type = FileStackNodeType(type);
node.data = ""; node.data = "";
tryReadString( tryReadString(
node.name(), file, "%s: Cannot read node #%" PRIu32 "'s file name: %s", fileName, nodeID node.name(), file, "%s: Cannot read node #%" PRIu32 "'s file name: %s", fileName, nodeID
); );
break; break;
case NODE_REPT: { case NODE_REPT: {
node.type = NODE_REPT;
uint32_t depth; uint32_t depth;
tryReadLong( tryReadLong(
depth, file, "%s: Cannot read node #%" PRIu32 "'s REPT depth: %s", fileName, nodeID depth, file, "%s: Cannot read node #%" PRIu32 "'s REPT depth: %s", fileName, nodeID
@@ -143,8 +147,13 @@ static void readFileStackNode(
nodeID nodeID
); );
} }
break;
} }
default:
fatal("%s: Node #%" PRIu32 " has unknown type 0x%02x", fileName, nodeID, type);
} }
node.isQuiet = (type & (1 << FSTACKNODE_QUIET_BIT)) != 0;
} }
// Reads a symbol from a file. // Reads a symbol from a file.
@@ -152,20 +161,25 @@ static void readSymbol(
FILE *file, Symbol &symbol, char const *fileName, std::vector<FileStackNode> const &fileNodes FILE *file, Symbol &symbol, char const *fileName, std::vector<FileStackNode> const &fileNodes
) { ) {
tryReadString(symbol.name, file, "%s: Cannot read symbol name: %s", fileName); tryReadString(symbol.name, file, "%s: Cannot read symbol name: %s", fileName);
tryGetc(
ExportLevel, uint8_t type;
symbol.type, tryGetc(type, file, "%s: Cannot read `%s`'s type: %s", fileName, symbol.name.c_str());
file, if (type >= SYMTYPE_INVALID) {
"%s: Cannot read `%s`'s type: %s", fatal("%s: `%s` has unknown type 0x%02x", fileName, symbol.name.c_str(), type);
fileName, } else {
symbol.name.c_str() symbol.type = ExportLevel(type);
); }
// If the symbol is defined in this file, read its definition // If the symbol is defined in this file, read its definition
if (symbol.type != SYMTYPE_IMPORT) { if (symbol.type != SYMTYPE_IMPORT) {
uint32_t nodeID; uint32_t nodeID;
tryReadLong( tryReadLong(
nodeID, file, "%s: Cannot read `%s`'s node ID: %s", fileName, symbol.name.c_str() nodeID, file, "%s: Cannot read `%s`'s node ID: %s", fileName, symbol.name.c_str()
); );
if (nodeID >= fileNodes.size()) {
fatal("%s: `%s` has invalid node ID #%" PRIu32, fileName, symbol.name.c_str(), nodeID);
}
symbol.src = &fileNodes[nodeID]; symbol.src = &fileNodes[nodeID];
tryReadLong( tryReadLong(
symbol.lineNo, symbol.lineNo,
@@ -212,6 +226,15 @@ static void readPatch(
sectName.c_str(), sectName.c_str(),
patchID patchID
); );
if (nodeID >= fileNodes.size()) {
fatal(
"%s: \"%s\"'s patch #%" PRIu32 " has invalid node ID #%" PRIu32,
fileName,
sectName.c_str(),
patchID,
nodeID
);
}
patch.src = &fileNodes[nodeID]; patch.src = &fileNodes[nodeID];
tryReadLong( tryReadLong(
@@ -247,9 +270,8 @@ static void readPatch(
patchID patchID
); );
PatchType type; uint8_t type;
tryGetc( tryGetc(
PatchType,
type, type,
file, file,
"%s: Cannot read \"%s\"'s patch #%" PRIu32 "'s type: %s", "%s: Cannot read \"%s\"'s patch #%" PRIu32 "'s type: %s",
@@ -257,7 +279,17 @@ static void readPatch(
sectName.c_str(), sectName.c_str(),
patchID patchID
); );
patch.type = type; if (type >= PATCHTYPE_INVALID) {
fatal(
"%s: \"%s\"'s patch #%" PRIu32 " has unknown type 0x%02x",
fileName,
sectName.c_str(),
patchID,
type
);
} else {
patch.type = PatchType(type);
}
uint32_t rpnSize; uint32_t rpnSize;
tryReadLong( tryReadLong(
@@ -281,13 +313,6 @@ static void readPatch(
} }
} }
// Sets a patch's `pcSection` from its `pcSectionID`.
static void
linkPatchToPCSect(Patch &patch, std::vector<std::unique_ptr<Section>> const &fileSections) {
patch.pcSection =
patch.pcSectionID != UINT32_MAX ? fileSections[patch.pcSectionID].get() : nullptr;
}
// Reads a section from a file. // Reads a section from a file.
static void readSection( static void readSection(
FILE *file, Section &section, char const *fileName, std::vector<FileStackNode> const &fileNodes FILE *file, Section &section, char const *fileName, std::vector<FileStackNode> const &fileNodes
@@ -296,11 +321,16 @@ static void readSection(
uint8_t byte; uint8_t byte;
tryReadString(section.name, file, "%s: Cannot read section name: %s", fileName); tryReadString(section.name, file, "%s: Cannot read section name: %s", fileName);
uint32_t nodeID; uint32_t nodeID;
tryReadLong( tryReadLong(
nodeID, file, "%s: Cannot read \"%s\"'s node ID: %s", fileName, section.name.c_str() nodeID, file, "%s: Cannot read \"%s\"'s node ID: %s", fileName, section.name.c_str()
); );
if (nodeID >= fileNodes.size()) {
fatal("%s: \"%s\" has invalid node ID #%" PRIu32, fileName, section.name.c_str(), nodeID);
}
section.src = &fileNodes[nodeID]; section.src = &fileNodes[nodeID];
tryReadLong( tryReadLong(
section.lineNo, section.lineNo,
file, file,
@@ -310,18 +340,23 @@ static void readSection(
); );
tryReadLong(tmp, file, "%s: Cannot read \"%s\"'s' size: %s", fileName, section.name.c_str()); tryReadLong(tmp, file, "%s: Cannot read \"%s\"'s' size: %s", fileName, section.name.c_str());
if (tmp < 0 || tmp > UINT16_MAX) { if (tmp < 0 || tmp > UINT16_MAX) {
fatal("\"%s\"'s section size ($%" PRIx32 ") is invalid", section.name.c_str(), tmp); fatal(
"%s: \"%s\"'s section size ($%" PRIx32 ") is invalid",
fileName,
section.name.c_str(),
tmp
);
} }
section.size = tmp; section.size = tmp;
section.offset = 0; section.offset = 0;
tryGetc(
uint8_t, byte, file, "%s: Cannot read \"%s\"'s type: %s", fileName, section.name.c_str() tryGetc(byte, file, "%s: Cannot read \"%s\"'s type: %s", fileName, section.name.c_str());
);
if (uint8_t type = byte & SECTTYPE_TYPE_MASK; type >= SECTTYPE_INVALID) { if (uint8_t type = byte & SECTTYPE_TYPE_MASK; type >= SECTTYPE_INVALID) {
fatal("\"%s\" has unknown section type 0x%02x", section.name.c_str(), type); fatal("%s: \"%s\" has unknown section type 0x%02x", fileName, section.name.c_str(), type);
} else { } else {
section.type = SectionType(type); section.type = SectionType(type);
} }
if (byte & (1 << SECTTYPE_UNION_BIT)) { if (byte & (1 << SECTTYPE_UNION_BIT)) {
section.modifier = SECTION_UNION; section.modifier = SECTION_UNION;
} else if (byte & (1 << SECTTYPE_FRAGMENT_BIT)) { } else if (byte & (1 << SECTTYPE_FRAGMENT_BIT)) {
@@ -339,14 +374,7 @@ static void readSection(
tryReadLong(tmp, file, "%s: Cannot read \"%s\"'s bank: %s", fileName, section.name.c_str()); tryReadLong(tmp, file, "%s: Cannot read \"%s\"'s bank: %s", fileName, section.name.c_str());
section.isBankFixed = tmp >= 0; section.isBankFixed = tmp >= 0;
section.bank = tmp; section.bank = tmp;
tryGetc( tryGetc(byte, file, "%s: Cannot read \"%s\"'s alignment: %s", fileName, section.name.c_str());
uint8_t,
byte,
file,
"%s: Cannot read \"%s\"'s alignment: %s",
fileName,
section.name.c_str()
);
if (byte > 16) { if (byte > 16) {
byte = 16; byte = 16;
} }
@@ -499,7 +527,16 @@ void obj_ReadFile(std::string const &filePath, size_t fileID) {
readSymbol(file, sym, fileName, nodes[fileID]); readSymbol(file, sym, fileName, nodes[fileID]);
sym_AddSymbol(sym); sym_AddSymbol(sym);
if (std::holds_alternative<Label>(sym.data)) { if (std::holds_alternative<Label>(sym.data)) {
++nbSymPerSect[std::get<Label>(sym.data).sectionID]; int32_t sectionID = std::get<Label>(sym.data).sectionID;
if (sectionID < 0 || static_cast<size_t>(sectionID) >= nbSymPerSect.size()) {
fatal(
"%s: `%s` has invalid section ID #%" PRId32,
fileName,
sym.name.c_str(),
sectionID
);
}
++nbSymPerSect[sectionID];
} }
} }
@@ -522,15 +559,41 @@ void obj_ReadFile(std::string const &filePath, size_t fileID) {
Assertion &assertion = patch_AddAssertion(); Assertion &assertion = patch_AddAssertion();
readAssertion(file, assertion, fileName, i, nodes[fileID]); readAssertion(file, assertion, fileName, i, nodes[fileID]);
linkPatchToPCSect(assertion.patch, fileSections);
if (assertion.patch.pcSectionID == UINT32_MAX) {
assertion.patch.pcSection = nullptr;
} else if (assertion.patch.pcSectionID >= fileSections.size()) {
fatal(
"%s: Assertion #%" PRIu32 "'s patch has invalid section ID #%" PRIu32,
fileName,
i,
assertion.patch.pcSectionID
);
} else {
assertion.patch.pcSection = fileSections[assertion.patch.pcSectionID].get();
}
assertion.fileSymbols = &fileSymbols; assertion.fileSymbols = &fileSymbols;
} }
// Give patches' PC section pointers to their sections // Give patches' PC section pointers to their sections
for (std::unique_ptr<Section> const &sect : fileSections) { for (std::unique_ptr<Section> const &sect : fileSections) {
if (sectTypeHasData(sect->type)) { if (!sectTypeHasData(sect->type)) {
for (Patch &patch : sect->patches) { continue;
linkPatchToPCSect(patch, fileSections); }
for (size_t i = 0; i < sect->patches.size(); ++i) {
if (Patch &patch = sect->patches[i]; patch.pcSectionID == UINT32_MAX) {
patch.pcSection = nullptr;
} else if (patch.pcSectionID >= fileSections.size()) {
fatal(
"%s: \"%s\"'s patch #%zu has invalid section ID #%" PRIu32,
fileName,
sect->name.c_str(),
i,
patch.pcSectionID
);
} else {
patch.pcSection = fileSections[patch.pcSectionID].get();
} }
} }
} }

View File

@@ -78,7 +78,8 @@ static uint32_t getRPNByte(uint8_t const *&expression, int32_t &size, Patch cons
} }
static Symbol const *getSymbol(std::vector<Symbol> const &symbolList, uint32_t index) { static Symbol const *getSymbol(std::vector<Symbol> const &symbolList, uint32_t index) {
assume(index != UINT32_MAX); // PC needs to be handled specially, not here assume(index != UINT32_MAX); // PC needs to be handled specially, not here
assume(index < symbolList.size()); // This needs to be checked before calling
Symbol const &symbol = symbolList[index]; Symbol const &symbol = symbolList[index];
// If the symbol is defined elsewhere... // If the symbol is defined elsewhere...
@@ -270,17 +271,19 @@ static int32_t computeRPNExpr(Patch const &patch, std::vector<Symbol> const &fil
value = op_shift_right_unsigned(popRPN(patch), value); value = op_shift_right_unsigned(popRPN(patch), value);
break; break;
case RPN_BANK_SYM: case RPN_BANK_SYM: {
value = 0; uint32_t symID = 0;
for (uint8_t shift = 0; shift < 32; shift += 8) { for (uint8_t shift = 0; shift < 32; shift += 8) {
value |= getRPNByte(expression, size, patch) << shift; symID |= getRPNByte(expression, size, patch) << shift;
} }
if (Symbol const *symbol = getSymbol(fileSymbols, value); !symbol) { if (symID >= fileSymbols.size()) {
fatalAt(patch, "Requested `BANK()` of invalid symbol ID #%" PRIu32, symID);
} else if (Symbol const *symbol = getSymbol(fileSymbols, symID); !symbol) {
errorAt( errorAt(
patch, patch,
"Requested `BANK()` of undefined symbol `%s`", "Requested `BANK()` of undefined symbol `%s`",
fileSymbols[value].name.c_str() fileSymbols[symID].name.c_str()
); );
isError = true; isError = true;
value = 1; value = 1;
@@ -290,12 +293,13 @@ static int32_t computeRPNExpr(Patch const &patch, std::vector<Symbol> const &fil
errorAt( errorAt(
patch, patch,
"Requested `BANK()` of non-label symbol `%s`", "Requested `BANK()` of non-label symbol `%s`",
fileSymbols[value].name.c_str() fileSymbols[symID].name.c_str()
); );
isError = true; isError = true;
value = 1; value = 1;
} }
break; break;
}
case RPN_BANK_SECT: { case RPN_BANK_SECT: {
// `expression` is not guaranteed to be '\0'-terminated. If it is not, // `expression` is not guaranteed to be '\0'-terminated. If it is not,
@@ -420,13 +424,13 @@ static int32_t computeRPNExpr(Patch const &patch, std::vector<Symbol> const &fil
} }
break; break;
case RPN_SYM: case RPN_SYM: {
value = 0; uint32_t symID = 0;
for (uint8_t shift = 0; shift < 32; shift += 8) { for (uint8_t shift = 0; shift < 32; shift += 8) {
value |= getRPNByte(expression, size, patch) << shift; symID |= getRPNByte(expression, size, patch) << shift;
} }
if (value == -1) { // PC if (symID == UINT32_MAX) { // PC
if (patch.pcSection) { if (patch.pcSection) {
value = patch.pcOffset + patch.pcSection->org; value = patch.pcOffset + patch.pcSection->org;
} else { } else {
@@ -434,9 +438,11 @@ static int32_t computeRPNExpr(Patch const &patch, std::vector<Symbol> const &fil
value = 0; value = 0;
isError = true; isError = true;
} }
} else if (Symbol const *symbol = getSymbol(fileSymbols, value); !symbol) { } else if (symID >= fileSymbols.size()) {
errorAt(patch, "Undefined symbol `%s`", fileSymbols[value].name.c_str()); fatalAt(patch, "Invalid symbol ID #%" PRIu32, symID);
sym_TraceLocalAliasedSymbols(fileSymbols[value].name); } else if (Symbol const *symbol = getSymbol(fileSymbols, symID); !symbol) {
errorAt(patch, "Undefined symbol `%s`", fileSymbols[symID].name.c_str());
sym_TraceLocalAliasedSymbols(fileSymbols[symID].name);
value = 0; value = 0;
isError = true; isError = true;
} else if (std::holds_alternative<Label>(symbol->data)) { } else if (std::holds_alternative<Label>(symbol->data)) {
@@ -447,6 +453,7 @@ static int32_t computeRPNExpr(Patch const &patch, std::vector<Symbol> const &fil
} }
break; break;
} }
}
pushRPN(value, isError); pushRPN(value, isError);
} }