Files
rgbds/src/link/patch.cpp

597 lines
15 KiB
C++

// SPDX-License-Identifier: MIT
#include "link/patch.hpp"
#include <deque>
#include <inttypes.h>
#include <limits.h>
#include <stdint.h>
#include <variant>
#include <vector>
#include "diagnostics.hpp"
#include "helpers.hpp" // assume
#include "linkdefs.hpp"
#include "opmath.hpp"
#include "verbosity.hpp"
#include "link/section.hpp"
#include "link/symbol.hpp"
#include "link/warning.hpp"
// All link-time assertions created through `patch_AddAssertion`.
// New entries are inserted at the front, so `patch_CheckAssertions` visits the
// most recently added assertion first.
static std::deque<Assertion> assertions;
// One slot of the RPN evaluation stack.
struct RPNStackEntry {
    int32_t value;  // The operand itself
    bool errorFlag; // Set when `value` is only a stand-in inserted while recovering from an error
};

// Operand stack used while evaluating an RPN expression; its front element is the stack's top.
static std::deque<RPNStackEntry> rpnStack;

// Pushes `value` on top of the RPN stack.
// `comesFromError` is stored alongside it, so that whoever pops the entry can
// tell a real operand from an error-recovery placeholder.
static void pushRPN(int32_t value, bool comesFromError) {
    rpnStack.push_front(RPNStackEntry{value, comesFromError});
}
// This flag tracks whether the RPN op that is currently being evaluated
// has popped any values with the error flag set.
// `computeRPNExpr` clears it before each command, `popRPN` ORs in the error flag
// of every entry it pops, and the `diagnosticAt` / `firstErrorAt` macros both
// consult it (to avoid piling errors on top of an earlier one) and set it.
static bool isError = false;
// Emits diagnostic `id` for `patch` through `warningAt`.
// When the diagnostic's configured behavior is `WarningBehavior::ERROR`, it is
// handled like an error: it is skipped if `isError` is already set, and it sets
// `isError` afterwards. Diagnostics with any other behavior are always emitted
// and leave `isError` untouched.
#define diagnosticAt(patch, id, ...) \
    do { \
        bool const promotedToError = warnings.getWarningBehavior(id) == WarningBehavior::ERROR; \
        if (!(isError && promotedToError)) { \
            warningAt(patch, id, __VA_ARGS__); \
        } \
        isError |= promotedToError; \
    } while (0)
// Reports an error through `errorAt` and then sets `isError`, but only if
// `isError` was not already set; otherwise the macro does nothing.
// (The `break` only leaves the macro's own `do { } while (0)` wrapper.)
#define firstErrorAt(...) \
    do { \
        if (isError) { \
            break; \
        } \
        errorAt(__VA_ARGS__); \
        isError = true; \
    } while (0)
// Removes the top entry of the RPN stack and returns its value.
// If the entry carries the error flag, `isError` gets set (it is never cleared here).
// Popping from an empty stack is reported through `fatalAt` as an internal error.
static int32_t popRPN(Patch const &patch) {
    if (rpnStack.empty()) {
        fatalAt(patch, "Internal error, RPN stack empty");
    }

    // Copy the entry's fields out before the entry itself is destroyed
    auto const [topValue, topIsPlaceholder] = rpnStack.front();
    rpnStack.pop_front();

    if (topIsPlaceholder) {
        isError = true;
    }
    return topValue;
}
// RPN operators
// Reads the next byte of an RPN expression.
// `expression` is advanced past the byte, and `size` (the number of bytes left)
// is decremented. Reading when no bytes are left is reported through `fatalAt`
// as an internal error.
// The byte is returned widened to `uint32_t`.
static uint32_t getRPNByte(uint8_t const *&expression, int32_t &size, Patch const &patch) {
    bool const exhausted = size == 0;
    --size; // Decremented whether or not any bytes were left
    if (exhausted) {
        fatalAt(patch, "Internal error, RPN expression overread");
    }

    uint8_t const byte = *expression;
    ++expression;
    return byte;
}
// Returns the symbol that entry `index` of `symbolList` stands for.
// Entries of type `SYMTYPE_IMPORT` are looked up by name through `sym_GetSymbol`,
// and whatever that lookup yields is returned; callers check the result for null.
// Any other entry is returned as-is.
// `index` must not be `UINT32_MAX`: that value denotes PC, which callers handle themselves.
static Symbol const *getSymbol(std::vector<Symbol> const &symbolList, uint32_t index) {
    assume(index != UINT32_MAX); // PC needs to be handled specially, not here

    Symbol const &entry = symbolList[index];
    if (entry.type != SYMTYPE_IMPORT) {
        // Not an import: the entry can be used directly
        return &entry;
    }
    // The symbol is defined elsewhere, so go find that definition
    return sym_GetSymbol(entry.name);
}
// Computes a patch's value by evaluating its RPN expression.
//
// The expression is a byte string of commands, some of which are followed by
// inline operands (a little-endian 32-bit value, a NUL-terminated section name,
// or a single byte). Each command pops its operands from `rpnStack` and pushes
// one result back.
//
// `fileSymbols` is the symbol table that symbol indices in the expression refer to.
//
// Error handling: `isError` is cleared before each command; a command's result is
// pushed with the error flag set if an error was reported while computing it, or
// if any of its operands already carried that flag. Flagged values are
// placeholders, and popping one sets `isError` again, which is how
// `firstErrorAt` / `diagnosticAt` avoid reporting follow-up errors.
//
// On return, `isError` is set if and only if the returned value carries the error flag.
// Malformed expressions (reading past the end, popping an empty stack, or an
// unknown command byte) are reported through `fatalAt`.
static int32_t computeRPNExpr(Patch const &patch, std::vector<Symbol> const &fileSymbols) {
    uint8_t const *expression = patch.rpnExpression.data();
    int32_t size = static_cast<int32_t>(patch.rpnExpression.size());

    // Reads the little-endian 32-bit operand located at the current position.
    auto readLong = [&]() -> int32_t {
        uint32_t word = 0;
        for (uint8_t shift = 0; shift < 32; shift += 8) {
            word |= getRPNByte(expression, size, patch) << shift;
        }
        return static_cast<int32_t>(word);
    };

    // Returns the section name located at the current position, and skips past
    // its terminator.
    // `expression` is not guaranteed to be '\0'-terminated. If it is not,
    // `getRPNByte` will have a fatal internal error.
    auto readSectionName = [&]() -> char const * {
        char const *name = reinterpret_cast<char const *>(expression);
        while (getRPNByte(expression, size, patch)) {}
        return name;
    };

    rpnStack.clear();

    while (size > 0) {
        RPNCommand command = static_cast<RPNCommand>(getRPNByte(expression, size, patch));

        isError = false;

        // Be VERY careful with two `popRPN` in the same expression.
        // C++ does not guarantee order of evaluation of operands!
        // So, if there are two `popRPN` in the same expression, make
        // sure the operation is commutative.
        int32_t value = 0;
        switch (command) {
        case RPN_ADD:
            value = popRPN(patch) + popRPN(patch);
            break;
        case RPN_SUB:
            value = popRPN(patch);
            value = popRPN(patch) - value;
            break;
        case RPN_MUL:
            value = popRPN(patch) * popRPN(patch);
            break;

        case RPN_DIV:
            value = popRPN(patch);
            if (value == 0) {
                firstErrorAt(patch, "Division by 0");
                popRPN(patch); // The dividend still has to be removed from the stack
                value = 0;
            } else if (int32_t lval = popRPN(patch); lval == INT32_MIN && value == -1) {
                diagnosticAt(
                    patch,
                    WARNING_DIV,
                    "Division of %" PRId32 " by -1 yields %" PRId32,
                    INT32_MIN,
                    INT32_MIN
                );
                value = INT32_MIN;
            } else {
                value = op_divide(lval, value);
            }
            break;
        case RPN_MOD:
            value = popRPN(patch);
            if (value == 0) {
                firstErrorAt(patch, "Modulo by 0");
                popRPN(patch); // The left-hand operand still has to be removed from the stack
                value = 0;
            } else {
                value = op_modulo(popRPN(patch), value);
            }
            break;
        case RPN_NEG:
            value = op_neg(popRPN(patch));
            break;
        case RPN_EXP:
            value = popRPN(patch);
            if (value < 0) {
                firstErrorAt(patch, "Exponent by negative value %" PRId32, value);
                popRPN(patch); // The base still has to be removed from the stack
                value = 0;
            } else {
                value = op_exponent(popRPN(patch), value);
            }
            break;

        case RPN_HIGH:
            value = op_high(popRPN(patch));
            break;
        case RPN_LOW:
            value = op_low(popRPN(patch));
            break;

        case RPN_BITWIDTH:
            value = op_bitwidth(popRPN(patch));
            break;
        case RPN_TZCOUNT:
            value = op_tzcount(popRPN(patch));
            break;

        case RPN_OR:
            value = popRPN(patch) | popRPN(patch);
            break;
        case RPN_AND:
            value = popRPN(patch) & popRPN(patch);
            break;
        case RPN_XOR:
            value = popRPN(patch) ^ popRPN(patch);
            break;
        case RPN_NOT:
            value = ~popRPN(patch);
            break;

        case RPN_LOGAND:
            value = popRPN(patch);
            value = popRPN(patch) && value;
            break;
        case RPN_LOGOR:
            value = popRPN(patch);
            value = popRPN(patch) || value;
            break;
        case RPN_LOGNOT:
            value = !popRPN(patch);
            break;

        case RPN_LOGEQ:
            value = popRPN(patch) == popRPN(patch);
            break;
        case RPN_LOGNE:
            value = popRPN(patch) != popRPN(patch);
            break;
        case RPN_LOGGT:
            value = popRPN(patch);
            value = popRPN(patch) > value;
            break;
        case RPN_LOGLT:
            value = popRPN(patch);
            value = popRPN(patch) < value;
            break;
        case RPN_LOGGE:
            value = popRPN(patch);
            value = popRPN(patch) >= value;
            break;
        case RPN_LOGLE:
            value = popRPN(patch);
            value = popRPN(patch) <= value;
            break;

        case RPN_SHL:
            value = popRPN(patch);
            if (value < 0) {
                diagnosticAt(
                    patch, WARNING_SHIFT_AMOUNT, "Shifting left by negative amount %" PRId32, value
                );
            }
            if (value >= 32) {
                diagnosticAt(
                    patch, WARNING_SHIFT_AMOUNT, "Shifting left by large amount %" PRId32, value
                );
            }
            value = op_shift_left(popRPN(patch), value);
            break;
        case RPN_SHR: {
            value = popRPN(patch);
            int32_t lval = popRPN(patch);
            if (lval < 0) {
                diagnosticAt(patch, WARNING_SHIFT, "Shifting right negative value %" PRId32, lval);
            }
            if (value < 0) {
                diagnosticAt(
                    patch, WARNING_SHIFT_AMOUNT, "Shifting right by negative amount %" PRId32, value
                );
            }
            if (value >= 32) {
                diagnosticAt(
                    patch, WARNING_SHIFT_AMOUNT, "Shifting right by large amount %" PRId32, value
                );
            }
            value = op_shift_right(lval, value);
            break;
        }
        case RPN_USHR:
            value = popRPN(patch);
            if (value < 0) {
                diagnosticAt(
                    patch, WARNING_SHIFT_AMOUNT, "Shifting right by negative amount %" PRId32, value
                );
            }
            if (value >= 32) {
                diagnosticAt(
                    patch, WARNING_SHIFT_AMOUNT, "Shifting right by large amount %" PRId32, value
                );
            }
            value = op_shift_right_unsigned(popRPN(patch), value);
            break;

        case RPN_BANK_SYM:
            // The operand is an index into `fileSymbols`
            value = readLong();
            if (Symbol const *symbol = getSymbol(fileSymbols, value); !symbol) {
                errorAt(
                    patch,
                    "Requested `BANK()` of undefined symbol `%s`",
                    fileSymbols[value].name.c_str()
                );
                isError = true;
                value = 1;
            } else if (std::holds_alternative<Label>(symbol->data)) {
                value = std::get<Label>(symbol->data).section->bank;
            } else {
                errorAt(
                    patch,
                    "Requested `BANK()` of non-label symbol `%s`",
                    fileSymbols[value].name.c_str()
                );
                isError = true;
                value = 1;
            }
            break;
        case RPN_BANK_SECT: {
            char const *name = readSectionName();
            if (Section const *sect = sect_GetSection(name); !sect) {
                errorAt(patch, "Requested `BANK()` of undefined section \"%s\"", name);
                isError = true;
                value = 1;
            } else {
                value = sect->bank;
            }
            break;
        }
        case RPN_BANK_SELF:
            if (!patch.pcSection) {
                errorAt(patch, "PC has no bank outside of a section");
                isError = true;
                value = 1;
            } else {
                value = patch.pcSection->bank;
            }
            break;

        case RPN_SIZEOF_SECT: {
            char const *name = readSectionName();
            if (Section const *sect = sect_GetSection(name); !sect) {
                errorAt(patch, "Requested `SIZEOF()` of undefined section \"%s\"", name);
                isError = true;
                value = 1;
            } else {
                value = sect->size;
            }
            break;
        }
        case RPN_STARTOF_SECT: {
            char const *name = readSectionName();
            if (Section const *sect = sect_GetSection(name); !sect) {
                errorAt(patch, "Requested `STARTOF()` of undefined section \"%s\"", name);
                isError = true;
                value = 1;
            } else {
                assume(sect->offset == 0);
                value = sect->org;
            }
            break;
        }

        case RPN_SIZEOF_SECTTYPE:
            // The operand is a single byte holding the section type
            value = getRPNByte(expression, size, patch);
            if (value < 0 || value >= SECTTYPE_INVALID) {
                errorAt(patch, "Requested `SIZEOF()` of an invalid section type");
                isError = true;
                value = 0;
            } else {
                value = sectionTypeInfo[value].size;
            }
            break;
        case RPN_STARTOF_SECTTYPE:
            // The operand is a single byte holding the section type
            value = getRPNByte(expression, size, patch);
            if (value < 0 || value >= SECTTYPE_INVALID) {
                errorAt(patch, "Requested `STARTOF()` of an invalid section type");
                isError = true;
                value = 0;
            } else {
                value = sectionTypeInfo[value].startAddr;
            }
            break;

        case RPN_HRAM:
            value = popRPN(patch);
            if (value < 0xFF00 || value > 0xFFFF) {
                firstErrorAt(
                    patch,
                    "Address $%" PRIx32 " for `LDH` is not in HRAM range; use `LD` instead",
                    value
                );
                value = 0;
            }
            // Only the low byte of the address is kept
            value &= 0xFF;
            break;
        case RPN_RST:
            value = popRPN(patch);
            // Acceptable values are 0x00, 0x08, 0x10, ..., 0x38
            if (value & ~0x38) {
                firstErrorAt(
                    patch, "Value $%" PRIx32 " is not a `RST` vector; use `CALL` instead", value
                );
                value = 0;
            }
            value |= 0xC7;
            break;
        case RPN_BIT_INDEX: {
            value = popRPN(patch);
            // The operand is a single byte, which gets combined with the shifted bit index
            int32_t mask = getRPNByte(expression, size, patch);
            // Acceptable values are 0 to 7
            if (value & ~0x07) {
                firstErrorAt(patch, "Value $%" PRIx32 " is not a bit index", value);
                value = 0;
            }
            value = mask | (value << 3);
            break;
        }

        case RPN_CONST:
            value = readLong();
            break;
        case RPN_SYM:
            // The operand is an index into `fileSymbols`, or -1 for PC
            value = readLong();
            if (value == -1) { // PC
                if (patch.pcSection) {
                    value = patch.pcOffset + patch.pcSection->org;
                } else {
                    errorAt(patch, "PC has no value outside of a section");
                    value = 0;
                    isError = true;
                }
            } else if (Symbol const *symbol = getSymbol(fileSymbols, value); !symbol) {
                errorAt(patch, "Undefined symbol `%s`", fileSymbols[value].name.c_str());
                sym_TraceLocalAliasedSymbols(fileSymbols[value].name);
                // `value` keeps holding the symbol index; it gets pushed flagged as a placeholder
                isError = true;
            } else if (std::holds_alternative<Label>(symbol->data)) {
                Label const &label = std::get<Label>(symbol->data);
                value = label.section->org + label.offset;
            } else {
                value = std::get<int32_t>(symbol->data);
            }
            break;

        default:
            // No case above recognized this byte, so the expression is corrupt.
            // Which operand bytes (if any) follow an unknown command cannot be known,
            // so the rest of the expression cannot be decoded reliably: give up
            // instead of pushing a meaningless value and misparsing what follows.
            fatalAt(
                patch,
                "Internal error, unknown RPN command $%02x",
                static_cast<unsigned int>(command)
            );
        }

        pushRPN(value, isError);
    }

    if (rpnStack.size() > 1) {
        errorAt(patch, "RPN stack has %zu entries on exit, not 1", rpnStack.size());
    }

    // Reset the flag so that, after the final pop, it reflects only the result's own error flag
    isError = false;
    return popRPN(patch);
}
// Creates a new, default-constructed assertion at the front of `assertions`,
// and returns a reference to it so that the caller can fill it in.
Assertion &patch_AddAssertion() {
    assertions.emplace_front();
    return assertions.front();
}
void patch_CheckAssertions() {
verbosePrint(VERB_NOTICE, "Checking assertions...\n");
for (Assertion &assert : assertions) {
int32_t value = computeRPNExpr(assert.patch, *assert.fileSymbols);
AssertionType type = static_cast<AssertionType>(assert.patch.type);
if (!isError && !value) {
switch (type) {
case ASSERT_FATAL:
fatalAt(
assert.patch,
"%s",
!assert.message.empty() ? assert.message.c_str() : "assert failure"
);
case ASSERT_ERROR:
errorAt(
assert.patch,
"%s",
!assert.message.empty() ? assert.message.c_str() : "assert failure"
);
break;
case ASSERT_WARN:
warningAt(
assert.patch,
WARNING_ASSERT,
"%s",
!assert.message.empty() ? assert.message.c_str() : "assert failure"
);
break;
}
} else if (isError && type == ASSERT_FATAL) {
fatalAt(
assert.patch,
"Failed to evaluate assertion%s%s",
!assert.message.empty() ? ": " : "",
assert.message.c_str()
);
}
}
}
// Emits a truncation diagnostic if `v` does not fit in `n` bits.
// Values outside of [-(2^n), 2^n) trigger `WARNING_TRUNCATION_1`; values that are
// within that range but below -(2^(n-1)) trigger `WARNING_TRUNCATION_2`.
// `n` must be non-zero, and smaller than the bit width of `int`.
static void checkPatchSize(Patch const &patch, int32_t v, uint8_t n) {
    assume(n != 0); // That doesn't make sense
    assume(n < CHAR_BIT * sizeof(int)); // Otherwise `1 << n` is UB

    int const bound = 1 << n;
    char const *signHint = v < 0 ? " (may be negative?)" : "";

    if (v >= bound || v < -bound) {
        diagnosticAt(
            patch, WARNING_TRUNCATION_1, "Value $%" PRIx32 "%s is not %u-bit", v, signHint, n
        );
        return;
    }

    if (int const halfBound = 1 << (n - 1); v < -halfBound) {
        diagnosticAt(
            patch, WARNING_TRUNCATION_2, "Value $%" PRIx32 "%s is not %u-bit", v, signHint, n
        );
    }
}
// Evaluates all of `section`'s patches, and writes the results into `dataSection`'s data.
// Each patch is written at its own offset plus `section.offset`. A patch that
// would extend past the end of the data is reported as an error and skipped.
static void applyFilePatches(Section &section, Section &dataSection) {
    verbosePrint(VERB_INFO, "Patching section \"%s\"...\n", section.name.c_str());

    // How many bytes each patch type writes, indexed by patch type
    static constexpr uint8_t bytesPerType[PATCHTYPE_INVALID] = {
        1, // PATCHTYPE_BYTE
        2, // PATCHTYPE_WORD
        4, // PATCHTYPE_LONG
        1, // PATCHTYPE_JR
    };

    for (Patch &patch : section.patches) {
        int32_t const value = computeRPNExpr(patch, *section.fileSymbols);
        uint16_t const offset = patch.offset + section.offset;
        uint8_t const width = bytesPerType[patch.type];
        int const patchEnd = offset + width;
        auto const available = dataSection.data.size();

        if (available < patchEnd) {
            errorAt(
                patch,
                "Patch would write %zu bytes past the end of section \"%s\" (%zu bytes long)",
                patchEnd - available,
                dataSection.name.c_str(),
                available
            );
            continue;
        }

        if (patch.type == PATCHTYPE_JR) { // `jr` is quite unlike the others...
            // The displacement is relative to the byte *after* the operand,
            // and the patch's PC is 2 bytes lower than that reference point
            uint16_t const referencePC = patch.pcSection->org + patch.pcOffset + 2;
            int16_t const displacement = value - referencePC;
            bool const reachable = displacement >= -128 && displacement <= 127;
            if (!reachable) {
                firstErrorAt(
                    patch,
                    "`JR` target must be between -128 and 127 bytes away, not %" PRId16
                    "; use `JP` instead",
                    displacement
                );
            }
            dataSection.data[offset] = displacement & 0xFF;
            continue;
        }

        // Any other patch type writes `width` bytes of the value, low byte first.
        // Only widths narrower than `int` are checked for truncation.
        if (width < sizeof(int)) {
            checkPatchSize(patch, value, width * 8);
        }
        for (uint8_t i = 0; i < width; ++i) {
            dataSection.data[offset + i] = (value >> (8 * i)) & 0xFF;
        }
    }
}
// Applies the patches of every "piece" of `section` (as enumerated by
// `section.pieces()`) to `section`'s own data.
// Sections whose type holds no data are left alone.
static void applyPatches(Section &section) {
    if (sectTypeHasData(section.type)) {
        for (Section &piece : section.pieces()) {
            applyFilePatches(piece, section);
        }
    }
}
// Entry point for patching: hands `applyPatches` to `sect_ForEach`, which
// presumably invokes it once per section (its definition is not in this file).
void patch_ApplyPatches() {
sect_ForEach(applyPatches);
}