Skip to content

Commit

Permalink
Use enum classes for error conditions to make things explicit
Browse files Browse the repository at this point in the history
  • Loading branch information
Ri0n committed Apr 26, 2024
1 parent ab2f925 commit 2a9b7be
Show file tree
Hide file tree
Showing 5 changed files with 137 additions and 127 deletions.
124 changes: 63 additions & 61 deletions src/xmpp/xmpp-core/protocol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include <QByteArray>
#include <QList>
#include <QtCrypto>
#include <optional>
#include <qca.h>

using namespace XMPP;
Expand Down Expand Up @@ -95,46 +96,46 @@ StreamFeatures::StreamFeatures()
// BasicProtocol
//----------------------------------------------------------------------------
BasicProtocol::SASLCondEntry BasicProtocol::saslCondTable[] = {
{ "aborted", Aborted },
{ "account-disabled", AccountDisabled },
{ "credentials-expired", CredentialsExpired },
{ "encryption-required", EncryptionRequired },
{ "incorrect-encoding", IncorrectEncoding },
{ "invalid-authzid", InvalidAuthzid },
{ "invalid-mechanism", InvalidMech },
{ "malformed-request", MalformedRequest },
{ "mechanism-too-weak", MechTooWeak },
{ "not-authorized", NotAuthorized },
{ "temporary-auth-failure", TemporaryAuthFailure },
{ nullptr, 0 },
{ "aborted", SASLCond::Aborted },
{ "account-disabled", SASLCond::AccountDisabled },
{ "credentials-expired", SASLCond::CredentialsExpired },
{ "encryption-required", SASLCond::EncryptionRequired },
{ "incorrect-encoding", SASLCond::IncorrectEncoding },
{ "invalid-authzid", SASLCond::InvalidAuthzid },
{ "invalid-mechanism", SASLCond::InvalidMechanism },
{ "malformed-request", SASLCond::MalformedRequest },
{ "mechanism-too-weak", SASLCond::MechanismTooWeak },
{ "not-authorized", SASLCond::NotAuthorized },
{ "temporary-auth-failure", SASLCond::TemporaryAuthFailure },
{ nullptr, SASLCond(0) },
};

BasicProtocol::StreamCondEntry BasicProtocol::streamCondTable[] = {
{ "bad-format", BadFormat },
{ "bad-namespace-prefix", BadNamespacePrefix },
{ "conflict", Conflict },
{ "connection-timeout", ConnectionTimeout },
{ "host-gone", HostGone },
{ "host-unknown", HostUnknown },
{ "improper-addressing", ImproperAddressing },
{ "internal-server-error", InternalServerError },
{ "invalid-from", InvalidFrom },
{ "invalid-namespace", InvalidNamespace },
{ "invalid-xml", InvalidXml },
{ "not-authorized", StreamNotAuthorized },
{ "not-well-formed", NotWellFormed },
{ "policy-violation", PolicyViolation },
{ "remote-connection-failed", RemoteConnectionFailed },
{ "reset", StreamReset },
{ "resource-constraint", ResourceConstraint },
{ "restricted-xml", RestrictedXml },
{ "see-other-host", SeeOtherHost },
{ "system-shutdown", SystemShutdown },
{ "undefined-condition", UndefinedCondition },
{ "unsupported-encoding", UnsupportedEncoding },
{ "unsupported-stanza-type", UnsupportedStanzaType },
{ "unsupported-version", UnsupportedVersion },
{ nullptr, 0 },
{ "bad-format", StreamCond::BadFormat },
{ "bad-namespace-prefix", StreamCond::BadNamespacePrefix },
{ "conflict", StreamCond::Conflict },
{ "connection-timeout", StreamCond::ConnectionTimeout },
{ "host-gone", StreamCond::HostGone },
{ "host-unknown", StreamCond::HostUnknown },
{ "improper-addressing", StreamCond::ImproperAddressing },
{ "internal-server-error", StreamCond::InternalServerError },
{ "invalid-from", StreamCond::InvalidFrom },
{ "invalid-namespace", StreamCond::InvalidNamespace },
{ "invalid-xml", StreamCond::InvalidXml },
{ "not-authorized", StreamCond::NotAuthorized },
{ "not-well-formed", StreamCond::NotWellFormed },
{ "policy-violation", StreamCond::PolicyViolation },
{ "remote-connection-failed", StreamCond::RemoteConnectionFailed },
{ "reset", StreamCond::Reset },
{ "resource-constraint", StreamCond::ResourceConstraint },
{ "restricted-xml", StreamCond::RestrictedXml },
{ "see-other-host", StreamCond::SeeOtherHost },
{ "system-shutdown", StreamCond::SystemShutdown },
{ "undefined-condition", StreamCond::UndefinedCondition },
{ "unsupported-encoding", StreamCond::UnsupportedEncoding },
{ "unsupported-stanza-type", StreamCond::UnsupportedStanzaType },
{ "unsupported-version", StreamCond::UnsupportedVersion },
{ nullptr, StreamCond(0) },
};

BasicProtocol::BasicProtocol() : XmlProtocol() { init(); }
Expand All @@ -143,7 +144,7 @@ BasicProtocol::~BasicProtocol() { }

void BasicProtocol::init()
{
errCond = -1;
errCond = {};
sasl_authed = false;
doShutdown = false;
delayedError = false;
Expand Down Expand Up @@ -210,7 +211,7 @@ QDomElement BasicProtocol::recvStanza()

void BasicProtocol::shutdown() { doShutdown = true; }

void BasicProtocol::shutdownWithError(int cond, const QString &str)
void BasicProtocol::shutdownWithError(StreamCond cond, const QString &str)
{
otherHost = str;
delayErrorAndClose(cond);
Expand All @@ -236,25 +237,25 @@ void BasicProtocol::setSASLNext(const QByteArray &step) { sasl_step = step; }

void BasicProtocol::setSASLAuthed() { sasl_authed = true; }

int BasicProtocol::stringToSASLCond(const QString &s)
std::optional<BasicProtocol::SASLCond> BasicProtocol::stringToSASLCond(const QString &s)
{
for (int n = 0; saslCondTable[n].str; ++n) {
if (s == saslCondTable[n].str)
return saslCondTable[n].cond;
}
return -1;
return {};
}

int BasicProtocol::stringToStreamCond(const QString &s)
std::optional<BasicProtocol::StreamCond> BasicProtocol::stringToStreamCond(const QString &s)
{
for (int n = 0; streamCondTable[n].str; ++n) {
if (s == streamCondTable[n].str)
return streamCondTable[n].cond;
}
return -1;
return {};
}

QString BasicProtocol::saslCondToString(int x)
QString BasicProtocol::saslCondToString(SASLCond x)
{
for (int n = 0; saslCondTable[n].str; ++n) {
if (x == saslCondTable[n].cond)
Expand All @@ -263,7 +264,7 @@ QString BasicProtocol::saslCondToString(int x)
return QString();
}

QString BasicProtocol::streamCondToString(int x)
QString BasicProtocol::streamCondToString(StreamCond x)
{
for (int n = 0; streamCondTable[n].str; ++n) {
if (x == streamCondTable[n].cond)
Expand All @@ -281,13 +282,13 @@ void BasicProtocol::extractStreamError(const QDomElement &e)
QDomElement t = firstChildElement(e);
if (t.isNull() || t.namespaceURI() != NS_STREAMS) {
// probably old-style error
errCond = -1;
errCond = {};
errText = e.text();
} else
errCond = stringToStreamCond(t.tagName());

if (errCond != -1) {
if (errCond == SeeOtherHost)
if (errCond.has_value()) {
if (std::get<StreamCond>(*errCond) == StreamCond::SeeOtherHost)
otherHost = t.text();

auto nodes = e.elementsByTagNameNS(NS_STREAMS, "text");
Expand Down Expand Up @@ -320,7 +321,7 @@ void BasicProtocol::send(const QDomElement &e, bool clip) { writeElement(e, Type

void BasicProtocol::sendUrgent(const QDomElement &e, bool clip) { writeElement(e, TypeElement, false, clip, true); }

void BasicProtocol::sendStreamError(int cond, const QString &text, const QDomElement &appSpec)
void BasicProtocol::sendStreamError(StreamCond cond, const QString &text, const QDomElement &appSpec)
{
QDomElement se = doc.createElementNS(NS_ETHERX, "stream:error");
QDomElement err = doc.createElementNS(NS_STREAMS, streamCondToString(cond));
Expand All @@ -346,7 +347,7 @@ void BasicProtocol::sendStreamError(const QString &text)
writeElement(se, 100, false);
}

bool BasicProtocol::errorAndClose(int cond, const QString &text, const QDomElement &appSpec)
bool BasicProtocol::errorAndClose(StreamCond cond, const QString &text, const QDomElement &appSpec)
{
closeError = true;
errCond = cond;
Expand All @@ -363,7 +364,7 @@ bool BasicProtocol::error(int code)
return true;
}

void BasicProtocol::delayErrorAndClose(int cond, const QString &text, const QDomElement &appSpec)
void BasicProtocol::delayErrorAndClose(StreamCond cond, const QString &text, const QDomElement &appSpec)
{
errorCode = ErrStream;
errCond = cond;
Expand Down Expand Up @@ -415,7 +416,7 @@ void BasicProtocol::handleDocOpen(const Parser::Event &pe)
{
if (isIncoming()) {
if (xmlEncoding() != "UTF-8") {
delayErrorAndClose(UnsupportedEncoding);
delayErrorAndClose(StreamCond::UnsupportedEncoding);
return;
}
}
Expand Down Expand Up @@ -455,7 +456,7 @@ void BasicProtocol::handleDocOpen(const Parser::Event &pe)
handleStreamOpen(pe);
} else {
if (isIncoming())
delayErrorAndClose(BadFormat);
delayErrorAndClose(StreamCond::BadFormat);
else
delayError(ErrProtocol);
}
Expand All @@ -464,7 +465,7 @@ void BasicProtocol::handleDocOpen(const Parser::Event &pe)
bool BasicProtocol::handleError()
{
if (isIncoming())
return errorAndClose(NotWellFormed);
return errorAndClose(StreamCond::NotWellFormed);
else
return error(ErrParse);
}
Expand All @@ -485,7 +486,8 @@ bool BasicProtocol::doStep(const QDomElement &e)
// handle pending error
if (delayedError) {
if (isIncoming())
return errorAndClose(errCond, errText, errAppSpec);
// see delayErrorAndClose. it's the only place we set errCond
return errorAndClose(std::get<StreamCond>(*errCond), errText, errAppSpec);
else
return error(errorCode);
}
Expand Down Expand Up @@ -834,13 +836,13 @@ void CoreProtocol::handleStreamOpen(const Parser::Event &pe)

// verify namespace
if ((!server && ns != NS_CLIENT) || (server && ns != NS_SERVER) || (dialback && db != NS_DIALBACK)) {
delayErrorAndClose(InvalidNamespace);
delayErrorAndClose(StreamCond::InvalidNamespace);
return;
}

// verify version
if (version.major < 1 && !dialback) {
delayErrorAndClose(UnsupportedVersion);
delayErrorAndClose(StreamCond::UnsupportedVersion);
return;
}
} else {
Expand Down Expand Up @@ -1474,7 +1476,7 @@ bool CoreProtocol::normalStep(const QDomElement &e)
} else if (e.tagName() == "failure") {
QDomElement t = firstChildElement(e);
if (t.isNull() || t.namespaceURI() != NS_SASL)
errCond = -1;
errCond = {};
else
errCond = stringToSASLCond(t.tagName());

Expand Down Expand Up @@ -1518,7 +1520,7 @@ bool CoreProtocol::normalStep(const QDomElement &e)
jid_ = j;
return loginComplete();
} else {
errCond = -1;
errCond = {};

QDomElement err = e.elementsByTagNameNS(NS_CLIENT, "error").item(0).toElement();
if (!err.isNull()) {
Expand All @@ -1535,9 +1537,9 @@ bool CoreProtocol::normalStep(const QDomElement &e)
if (!t.isNull() && t.namespaceURI() == NS_STANZAS) {
QString cond = t.tagName();
if (cond == "not-allowed")
errCond = BindNotAllowed;
errCond = BindCond::BindNotAllowed;
else if (cond == "conflict")
errCond = BindConflict;
errCond = BindCond::BindConflict;
}
}

Expand Down
Loading

0 comments on commit 2a9b7be

Please sign in to comment.