Skip to content

Commit

Permalink
Report TLS parsing error (#28)
Browse files Browse the repository at this point in the history
Task/Issue URL: https://app.asana.com/0/488551667048375/1205766744020104/f

### Description
Report TLS parsing errors back to JVM land

### Steps to test this PR
- `cd src/test`
- `make`
- `./test_tls` should pass
  • Loading branch information
aitorvs authored Oct 20, 2023
1 parent aadc544 commit 03d1c2c
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 28 deletions.
2 changes: 2 additions & 0 deletions src/netguard/include/netguard.h
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,8 @@ void log_packet(const struct arguments *args, const packet_t *packet);
void dns_resolved(const struct arguments *args,
const char *qname, const char *aname, const char *resource, int ttl);

void report_tls_parsing_error(const struct arguments *args, jint error_code);

jboolean is_domain_blocked(const struct arguments *args, const char *name, jint uid);

jint get_uid_q(const struct arguments *args,
Expand Down
31 changes: 31 additions & 0 deletions src/netguard/netguard.c
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,37 @@ void dns_resolved(const struct arguments *args,
#endif
}

static jmethodID midReportTLSParsingError = NULL;

void report_tls_parsing_error(const struct arguments *args, jint error_code) {
#ifdef PROFILE_JNI
float mselapsed;
struct timeval start, end;
gettimeofday(&start, NULL);
#endif

jclass clsService = (*args->env)->GetObjectClass(args->env, args->instance);
ng_add_alloc(clsService, "clsService");

const char *signature = "(I)V";
if (midReportTLSParsingError == NULL)
midReportTLSParsingError = jniGetMethodID(args->env, clsService, "reportTLSParsingError", signature);

(*args->env)->CallVoidMethod(args->env, args->instance, midReportTLSParsingError, error_code);
jniCheckException(args->env);

(*args->env)->DeleteLocalRef(args->env, clsService);
ng_delete_alloc(clsService, __FILE__, __LINE__);

#ifdef PROFILE_JNI
gettimeofday(&end, NULL);
mselapsed = (end.tv_sec - start.tv_sec) * 1000.0 +
(end.tv_usec - start.tv_usec) / 1000.0;
if (mselapsed > PROFILE_JNI)
log_print(PLATFORM_LOG_PRIORITY_WARN, "is_domain_blocked %f", mselapsed);
#endif
}

static jmethodID midIsDomainBlocked = NULL;

jboolean is_domain_blocked(const struct arguments *args, const char *name, jint uid) {
Expand Down
5 changes: 4 additions & 1 deletion src/netguard/tls.c
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,11 @@ int is_sni_found_and_blocked(
memset(sn, 0, FQDN_LENGTH);
*sn = 0;

get_server_name(pkt, length, tls, sn);
int error_code = get_server_name(pkt, length, tls, sn);

if (error_code < 0) {
report_tls_parsing_error(args, error_code);
}
if (strlen(sn) == 0) {
log_print(PLATFORM_LOG_PRIORITY_INFO, "TLS server name not found");
return 0;
Expand Down
40 changes: 16 additions & 24 deletions src/netguard/tls_parser.c
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ static int parse_tls_server_name(const uint8_t *data, size_t data_len, char *ser
uint8_t tls_version_minor = data[2];
if (tls_version_major < 3) {
// receive handshake that can't support SNI
return -2;
return -4;
}

/* TLS record length */
Expand All @@ -79,12 +79,12 @@ static int parse_tls_server_name(const uint8_t *data, size_t data_len, char *ser
/* handshake */
size_t pos = TLS_HEADER_LEN;
if (pos + 1 > data_len) {
return -4;
return -5;
}

if (data[pos] != 0x1) {
// not a client hello
return -4;
return -6;
}

/* Skip past fixed length records:
Expand All @@ -97,34 +97,34 @@ static int parse_tls_server_name(const uint8_t *data, size_t data_len, char *ser
pos += 38;

// Session ID
if (pos + 1 > data_len) return -4;
if (pos + 1 > data_len) return -7;
len = (size_t)data[pos];
pos += 1 + len;

/* Cipher Suites */
if (pos + 2 > data_len) return -4;
if (pos + 2 > data_len) return -8;
len = ((size_t)data[pos] << 8) + (size_t)data[pos + 1];
pos += 2 + len;

/* Compression Methods */
if (pos + 1 > data_len) return -4;
if (pos + 1 > data_len) return -9;
len = (size_t)data[pos];
pos += 1 + len;

if (pos == data_len && tls_version_major == 3 && tls_version_minor == 0) {
// "Received SSL 3.0 handshake without extensions"
return -2;
return -10;
}

/* Extensions */
if (pos + 2 > data_len) {
return -4;
return -11;
}
len = ((size_t)data[pos] << 8) + (size_t)data[pos + 1];
pos += 2;

if (pos + len > data_len) {
return -4;
return -12;
}
return parse_extensions(data + pos, len, server_name);
}
Expand All @@ -144,16 +144,16 @@ static int parse_extensions(const uint8_t *data, size_t data_len, char *hostname
/* There can be only one extension of each type, so we break
our state and move p to beinnging of the extension here */
if (pos + 4 + len > data_len)
return -5;
return -20;
return parse_server_name_extension(data + pos + 4, len, hostname);
}
pos += 4 + len; /* Advance to the next extension header */
}
/* Check we ended where we expected to */
if (pos != data_len)
return -5;
return -21;

return -2;
return -22;
}

static int parse_server_name_extension(const uint8_t *data, size_t data_len, char *hostname) {
Expand All @@ -165,7 +165,7 @@ static int parse_server_name_extension(const uint8_t *data, size_t data_len, cha
(size_t)data[pos + 2];

if (pos + 3 + len > data_len) {
return -4;
return -30;
}

switch (data[pos]) { /* name type */
Expand All @@ -180,10 +180,10 @@ static int parse_server_name_extension(const uint8_t *data, size_t data_len, cha
}
/* Check we ended where we expected to */
if (pos != data_len) {
return -4;
return -31;
}

return -2;
return -32;
}

int get_server_name(
Expand All @@ -196,16 +196,8 @@ int get_server_name(
int error_code = parse_tls_server_name(tls, data_len, server_name);
if (error_code >= 0) {
log_print(PLATFORM_LOG_PRIORITY_DEBUG, "Found server name %s", server_name);
} else if (error_code == -1) {
log_print(PLATFORM_LOG_PRIORITY_DEBUG, "Incomplete TLs request");
} else if (error_code == -2) {
log_print(PLATFORM_LOG_PRIORITY_DEBUG, "No SNI header found");
} else if (error_code == -3) {
log_print(PLATFORM_LOG_PRIORITY_DEBUG, "invalid TLS client hello");
} else if (error_code == -4) {
log_print(PLATFORM_LOG_PRIORITY_DEBUG, "invalid TLS packet");
} else {
log_print(PLATFORM_LOG_PRIORITY_DEBUG, "Unknown error");
log_print(PLATFORM_LOG_PRIORITY_DEBUG, "TLS parsing error code %d", error_code);
}

return error_code;
Expand Down
6 changes: 3 additions & 3 deletions src/test/test_tls.c
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,7 @@ int main() {
error = get_server_name(pkt, sizeof(ssl30_request), pkt, sn);
assert(strcmp("localhost", sn) != 0);
assert(strlen(sn) == 0);
assert(error == -2);
assert(error == -10);

pkt = (uint8_t *)ssl20_client_hello;
memset(sn, 0, FQDN_LENGTH);
Expand All @@ -471,15 +471,15 @@ int main() {
error = get_server_name(pkt, sizeof(bad_data_1), pkt, sn);
assert(strcmp("localhost", sn) != 0);
assert(strlen(sn) == 0);
assert(error == -4);
assert(error == -12);

pkt = (uint8_t *)bad_data_2;
memset(sn, 0, FQDN_LENGTH);
*sn = 0;
error = get_server_name(pkt, sizeof(bad_data_2), pkt, sn);
assert(strcmp("localhost", sn) != 0);
assert(strlen(sn) == 0);
assert(error == -4);
assert(error == -12);

pkt = (uint8_t *)bad_data_3;
memset(sn, 0, FQDN_LENGTH);
Expand Down

0 comments on commit 03d1c2c

Please sign in to comment.