diff --git a/include/picotls.h b/include/picotls.h index 7a642707..218f6820 100644 --- a/include/picotls.h +++ b/include/picotls.h @@ -1588,6 +1588,7 @@ int ptls_handshake_is_complete(ptls_t *tls); int ptls_is_psk_handshake(ptls_t *tls); /** * return if a ECH handshake was performed, as well as optionally the kem and cipher-suite being used + * FIXME: this function always return false when the TLS session is exported and imported */ int ptls_is_ech_handshake(ptls_t *tls, uint8_t *config_id, ptls_hpke_kem_t **kem, ptls_hpke_cipher_suite_t **cipher); /** diff --git a/lib/picotls.c b/lib/picotls.c index c73c2dc2..2feba7ef 100644 --- a/lib/picotls.c +++ b/lib/picotls.c @@ -1615,7 +1615,7 @@ static int get_traffic_keys(ptls_aead_algorithm_t *aead, ptls_hash_algorithm_t * return ret; } -static int setup_traffic_protection(ptls_t *tls, int is_enc, const char *secret_label, size_t epoch, int skip_notify) +static int setup_traffic_protection(ptls_t *tls, int is_enc, const char *secret_label, size_t epoch, uint64_t seq, int skip_notify) { static const char *log_labels[2][4] = { {NULL, "CLIENT_EARLY_TRAFFIC_SECRET", "CLIENT_HANDSHAKE_TRAFFIC_SECRET", "CLIENT_TRAFFIC_SECRET_0"}, @@ -1645,7 +1645,7 @@ static int setup_traffic_protection(ptls_t *tls, int is_enc, const char *secret_ if ((ctx->aead = ptls_aead_new(tls->cipher_suite->aead, tls->cipher_suite->hash, is_enc, ctx->secret, tls->ctx->hkdf_label_prefix__obsolete)) == NULL) return PTLS_ERROR_NO_MEMORY; /* TODO obtain error from ptls_aead_new */ - ctx->seq = 0; + ctx->seq = seq; PTLS_DEBUGF("[%s] %02x%02x,%02x%02x\n", log_labels[ptls_is_server(tls)][epoch], (unsigned)ctx->secret[0], (unsigned)ctx->secret[1], (unsigned)ctx->aead->static_iv[0], (unsigned)ctx->aead->static_iv[1]); @@ -1664,7 +1664,7 @@ static int commission_handshake_secret(ptls_t *tls) free(tls->pending_handshake_secret); tls->pending_handshake_secret = NULL; - return setup_traffic_protection(tls, is_enc, NULL, 2, 1); + return setup_traffic_protection(tls, is_enc, NULL, 2, 0, 1); } static void log_client_random(ptls_t *tls) @@ -2479,7 +2479,7 @@ static int send_client_hello(ptls_t *tls, ptls_message_emitter_t *emitter, ptls_ if (tls->client.using_early_data) { assert(!is_second_flight); - if ((ret = setup_traffic_protection(tls, 1, "c e traffic", 1, 0)) != 0) + if ((ret = setup_traffic_protection(tls, 1, "c e traffic", 1, 0, 0)) != 0) goto Exit; if ((ret = push_change_cipher_spec(tls, emitter)) != 0) goto Exit; @@ -2795,7 +2795,7 @@ static int client_handle_hello(ptls_t *tls, ptls_message_emitter_t *emitter, ptl if ((ret = key_schedule_extract(tls->key_schedule, ecdh_secret)) != 0) goto Exit; - if ((ret = setup_traffic_protection(tls, 0, "s hs traffic", 2, 0)) != 0) + if ((ret = setup_traffic_protection(tls, 0, "s hs traffic", 2, 0, 0)) != 0) goto Exit; if (tls->client.using_early_data) { if ((tls->pending_handshake_secret = malloc(PTLS_MAX_DIGEST_SIZE)) == NULL) { @@ -2808,7 +2808,7 @@ static int client_handle_hello(ptls_t *tls, ptls_message_emitter_t *emitter, ptl (ret = tls->ctx->update_traffic_key->cb(tls->ctx->update_traffic_key, tls, 1, 2, tls->pending_handshake_secret)) != 0) goto Exit; } else { - if ((ret = setup_traffic_protection(tls, 1, "c hs traffic", 2, 0)) != 0) + if ((ret = setup_traffic_protection(tls, 1, "c hs traffic", 2, 0, 0)) != 0) goto Exit; } @@ -3373,7 +3373,7 @@ static int client_handle_finished(ptls_t *tls, ptls_message_emitter_t *emitter, /* update traffic keys by using messages upto ServerFinished, but commission them after sending ClientFinished */ if ((ret = key_schedule_extract(tls->key_schedule, ptls_iovec_init(NULL, 0))) != 0) goto Exit; - if ((ret = setup_traffic_protection(tls, 0, "s ap traffic", 3, 0)) != 0) + if ((ret = setup_traffic_protection(tls, 0, "s ap traffic", 3, 0, 0)) != 0) goto Exit; if ((ret = derive_secret(tls->key_schedule, send_secret, "c ap traffic")) != 0) goto Exit; @@ -3407,7 +3407,7 @@ static int client_handle_finished(ptls_t *tls, ptls_message_emitter_t *emitter, ret = send_finished(tls, emitter); memcpy(tls->traffic_protection.enc.secret, send_secret, sizeof(send_secret)); - if ((ret = setup_traffic_protection(tls, 1, NULL, 3, 0)) != 0) + if ((ret = setup_traffic_protection(tls, 1, NULL, 3, 0, 0)) != 0) goto Exit; tls->state = PTLS_STATE_CLIENT_POST_HANDSHAKE; @@ -4572,7 +4572,7 @@ static int server_handle_hello(ptls_t *tls, ptls_message_emitter_t *emitter, ptl } if ((ret = derive_exporter_secret(tls, 1)) != 0) goto Exit; - if ((ret = setup_traffic_protection(tls, 0, "c e traffic", 1, 0)) != 0) + if ((ret = setup_traffic_protection(tls, 0, "c e traffic", 1, 0, 0)) != 0) goto Exit; } @@ -4631,7 +4631,7 @@ static int server_handle_hello(ptls_t *tls, ptls_message_emitter_t *emitter, ptl /* create protection contexts for the handshake */ assert(tls->key_schedule->generation == 1); key_schedule_extract(tls->key_schedule, ecdh_secret); - if ((ret = setup_traffic_protection(tls, 1, "s hs traffic", 2, 0)) != 0) + if ((ret = setup_traffic_protection(tls, 1, "s hs traffic", 2, 0, 0)) != 0) goto Exit; if (tls->pending_handshake_secret != NULL) { if ((ret = derive_secret(tls->key_schedule, tls->pending_handshake_secret, "c hs traffic")) != 0) @@ -4640,7 +4640,7 @@ static int server_handle_hello(ptls_t *tls, ptls_message_emitter_t *emitter, ptl (ret = tls->ctx->update_traffic_key->cb(tls->ctx->update_traffic_key, tls, 0, 2, tls->pending_handshake_secret)) != 0) goto Exit; } else { - if ((ret = setup_traffic_protection(tls, 0, "c hs traffic", 2, 0)) != 0) + if ((ret = setup_traffic_protection(tls, 0, "c hs traffic", 2, 0, 0)) != 0) goto Exit; if (ch->psk.early_data_indication) tls->server.early_data_skipped_bytes = 0; @@ -4766,7 +4766,7 @@ static int server_finish_handshake(ptls_t *tls, ptls_message_emitter_t *emitter, assert(tls->key_schedule->generation == 2); if ((ret = key_schedule_extract(tls->key_schedule, ptls_iovec_init(NULL, 0))) != 0) goto Exit; - if ((ret = setup_traffic_protection(tls, 1, "s ap traffic", 3, 0)) != 0) + if ((ret = setup_traffic_protection(tls, 1, "s ap traffic", 3, 0, 0)) != 0) goto Exit; if ((ret = derive_secret(tls->key_schedule, tls->server.pending_traffic_secret, "c ap traffic")) != 0) goto Exit; @@ -4827,7 +4827,7 @@ static int server_handle_finished(ptls_t *tls, ptls_iovec_t message) memcpy(tls->traffic_protection.dec.secret, tls->server.pending_traffic_secret, sizeof(tls->server.pending_traffic_secret)); ptls_clear_memory(tls->server.pending_traffic_secret, sizeof(tls->server.pending_traffic_secret)); - if ((ret = setup_traffic_protection(tls, 0, NULL, 3, 0)) != 0) + if ((ret = setup_traffic_protection(tls, 0, NULL, 3, 0, 0)) != 0) return ret; ptls__key_schedule_update_hash(tls->key_schedule, message.base, message.len, 0); @@ -4847,7 +4847,7 @@ static int update_traffic_key(ptls_t *tls, int is_enc) "traffic upd", ptls_iovec_init(NULL, 0), NULL)) != 0) goto Exit; memcpy(tp->secret, secret, sizeof(secret)); - ret = setup_traffic_protection(tls, is_enc, NULL, 3, 1); + ret = setup_traffic_protection(tls, is_enc, NULL, 3, 0, 1); Exit: ptls_clear_memory(secret, sizeof(secret)); @@ -5017,6 +5017,28 @@ ptls_t *ptls_server_new(ptls_context_t *ctx) return tls; } +#define export_tls_params(output, is_server, session_reused, protocol_version, cipher, client_random, server_name, \ + negotiated_protocol, ver_block) \ + do { \ + const char *_server_name = (server_name); \ + ptls_iovec_t _negotiated_protocol = (negotiated_protocol); \ + ptls_buffer_push_block((output), 2, { \ + ptls_buffer_push((output), (is_server)); \ + ptls_buffer_push((output), (session_reused)); \ + ptls_buffer_push16((output), (protocol_version)); \ + ptls_buffer_push16((output), (cipher)->id); \ + ptls_buffer_pushv((output), (client_random), PTLS_HELLO_RANDOM_SIZE); \ + ptls_buffer_push_block((output), 2, { \ + size_t len = _server_name != NULL ? strlen(_server_name) : 0; \ + ptls_buffer_pushv((output), _server_name, len); \ + }); \ + ptls_buffer_push_block((output), 2, \ + { ptls_buffer_pushv((output), _negotiated_protocol.base, _negotiated_protocol.len); }); \ + ptls_buffer_push_block((output), 2, {ver_block}); /* version-specific block */ \ + ptls_buffer_push_block((output), 2, {}); /* for future extensions */ \ + }); \ + } while (0) + static int export_tls12_params(ptls_buffer_t *output, int is_server, int session_reused, ptls_cipher_suite_t *cipher, const void *client_random, const char *server_name, ptls_iovec_t negotiated_protocol, const void *enc_key, const void *enc_iv, uint64_t enc_seq, uint64_t enc_record_iv, @@ -5024,29 +5046,18 @@ static int export_tls12_params(ptls_buffer_t *output, int is_server, int session { int ret; - ptls_buffer_push_block(output, 2, { - ptls_buffer_push(output, is_server); - ptls_buffer_push(output, session_reused); - ptls_buffer_push16(output, PTLS_PROTOCOL_VERSION_TLS12); - ptls_buffer_push16(output, cipher->id); - ptls_buffer_pushv(output, client_random, PTLS_HELLO_RANDOM_SIZE); - ptls_buffer_push_block(output, 2, { - size_t len = server_name != NULL ? strlen(server_name) : 0; - ptls_buffer_pushv(output, server_name, len); - }); - ptls_buffer_push_block(output, 2, { ptls_buffer_pushv(output, negotiated_protocol.base, negotiated_protocol.len); }); - ptls_buffer_push_block(output, 2, { - ptls_buffer_pushv(output, enc_key, cipher->aead->key_size); - ptls_buffer_pushv(output, enc_iv, cipher->aead->tls12.fixed_iv_size); - ptls_buffer_push64(output, enc_seq); - if (cipher->aead->tls12.record_iv_size != 0) - ptls_buffer_push64(output, enc_record_iv); - ptls_buffer_pushv(output, dec_key, cipher->aead->key_size); - ptls_buffer_pushv(output, dec_iv, cipher->aead->tls12.fixed_iv_size); - ptls_buffer_push64(output, dec_seq); - }); - ptls_buffer_push_block(output, 2, {}); /* for future extensions */ - }); + export_tls_params(output, is_server, session_reused, PTLS_PROTOCOL_VERSION_TLS12, cipher, client_random, server_name, + negotiated_protocol, { + ptls_buffer_pushv(output, enc_key, cipher->aead->key_size); + ptls_buffer_pushv(output, enc_iv, cipher->aead->tls12.fixed_iv_size); + ptls_buffer_push64(output, enc_seq); + if (cipher->aead->tls12.record_iv_size != 0) + ptls_buffer_push64(output, enc_record_iv); + ptls_buffer_pushv(output, dec_key, cipher->aead->key_size); + ptls_buffer_pushv(output, dec_iv, cipher->aead->tls12.fixed_iv_size); + ptls_buffer_push64(output, dec_seq); + }); + ret = 0; Exit: return ret; @@ -5094,20 +5105,39 @@ int ptls_build_tls12_export_params(ptls_context_t *ctx, ptls_buffer_t *output, i int ptls_export(ptls_t *tls, ptls_buffer_t *output) { - /* TODO add tls13 support */ - if (!tls->traffic_protection.enc.tls12) - return PTLS_ERROR_LIBRARY; - ptls_iovec_t negotiated_protocol = ptls_iovec_init(tls->negotiated_protocol, tls->negotiated_protocol != NULL ? strlen(tls->negotiated_protocol) : 0); - return export_tls12_params(output, tls->is_server, tls->is_psk_handshake, tls->cipher_suite, tls->client_random, - tls->server_name, negotiated_protocol, tls->traffic_protection.enc.secret, - tls->traffic_protection.enc.secret + PTLS_MAX_SECRET_SIZE, tls->traffic_protection.enc.seq, - tls->traffic_protection.enc.tls12_enc_record_iv, tls->traffic_protection.dec.secret, - tls->traffic_protection.dec.secret + PTLS_MAX_SECRET_SIZE, tls->traffic_protection.dec.seq); + int ret; + + if (tls->state != PTLS_STATE_SERVER_POST_HANDSHAKE) { + ret = PTLS_ERROR_LIBRARY; + goto Exit; + } + + if (ptls_get_protocol_version(tls) == PTLS_PROTOCOL_VERSION_TLS13) { + export_tls_params(output, tls->is_server, tls->is_psk_handshake, PTLS_PROTOCOL_VERSION_TLS13, tls->cipher_suite, + tls->client_random, tls->server_name, negotiated_protocol, { + ptls_buffer_pushv(output, tls->traffic_protection.enc.secret, tls->cipher_suite->hash->digest_size); + ptls_buffer_push64(output, tls->traffic_protection.enc.seq); + ptls_buffer_pushv(output, tls->traffic_protection.dec.secret, tls->cipher_suite->hash->digest_size); + ptls_buffer_push64(output, tls->traffic_protection.dec.seq); + }); + ret = 0; + } else { + if ((ret = export_tls12_params(output, tls->is_server, tls->is_psk_handshake, tls->cipher_suite, tls->client_random, + tls->server_name, negotiated_protocol, tls->traffic_protection.enc.secret, + tls->traffic_protection.enc.secret + PTLS_MAX_SECRET_SIZE, tls->traffic_protection.enc.seq, + tls->traffic_protection.enc.tls12_enc_record_iv, tls->traffic_protection.dec.secret, + tls->traffic_protection.dec.secret + PTLS_MAX_SECRET_SIZE, + tls->traffic_protection.dec.seq)) != 0) + goto Exit; + } + +Exit: + return ret; } -static int build_tls12_traffic_protection(ptls_t *tls, int is_enc, const uint8_t **src, const uint8_t *const end) +static int import_tls12_traffic_protection(ptls_t *tls, int is_enc, const uint8_t **src, const uint8_t *const end) { struct st_ptls_traffic_protection_t *tp = is_enc ? &tls->traffic_protection.enc : &tls->traffic_protection.dec; @@ -5134,6 +5164,22 @@ static int build_tls12_traffic_protection(ptls_t *tls, int is_enc, const uint8_t return 0; } +static int import_tls13_traffic_protection(ptls_t *tls, int is_enc, const uint8_t **src, const uint8_t *const end) +{ + struct st_ptls_traffic_protection_t *tp = is_enc ? &tls->traffic_protection.enc : &tls->traffic_protection.dec; + + /* set properties */ + memcpy(tp->secret, *src, tls->cipher_suite->hash->digest_size); + *src += tls->cipher_suite->hash->digest_size; + if (ptls_decode64(&tp->seq, src, end) != 0) + return PTLS_ALERT_DECODE_ERROR; + + if (setup_traffic_protection(tls, is_enc, NULL, 3, tp->seq, 0) != 0) + return PTLS_ERROR_INCOMPATIBLE_KEY; + + return 0; +} + int ptls_import(ptls_context_t *ctx, ptls_t **tls, ptls_iovec_t params) { const uint8_t *src = params.base, *const end = src + params.len; @@ -5159,11 +5205,6 @@ int ptls_import(ptls_context_t *ctx, ptls_t **tls, ptls_iovec_t params) goto Exit; if ((ret = ptls_decode16(&csid, &src, end)) != 0) goto Exit; - (*tls)->cipher_suite = ptls_find_cipher_suite(ctx->tls12_cipher_suites, csid); - if ((*tls)->cipher_suite == NULL) { - ret = PTLS_ALERT_HANDSHAKE_FAILURE; - goto Exit; - } /* other version-independent stuff */ if (end - src < PTLS_HELLO_RANDOM_SIZE) { ret = PTLS_ALERT_DECODE_ERROR; @@ -5189,15 +5230,36 @@ int ptls_import(ptls_context_t *ctx, ptls_t **tls, ptls_iovec_t params) ptls_decode_open_block(src, end, 2, { switch (protocol_version) { case PTLS_PROTOCOL_VERSION_TLS12: + (*tls)->cipher_suite = ptls_find_cipher_suite(ctx->tls12_cipher_suites, csid); + if ((*tls)->cipher_suite == NULL) { + ret = PTLS_ALERT_HANDSHAKE_FAILURE; + goto Exit; + } /* setup AEAD keys */ - if ((ret = build_tls12_traffic_protection(*tls, 1, &src, end)) != 0) + if ((ret = import_tls12_traffic_protection(*tls, 1, &src, end)) != 0) goto Exit; - if ((ret = build_tls12_traffic_protection(*tls, 0, &src, end)) != 0) + if ((ret = import_tls12_traffic_protection(*tls, 0, &src, end)) != 0) + goto Exit; + break; + case PTLS_PROTOCOL_VERSION_TLS13: + (*tls)->cipher_suite = ptls_find_cipher_suite(ctx->cipher_suites, csid); + if ((*tls)->cipher_suite == NULL) { + ret = PTLS_ALERT_HANDSHAKE_FAILURE; + goto Exit; + } + /* setup AEAD keys */ + if (((*tls)->key_schedule = key_schedule_new((*tls)->cipher_suite, NULL, (*tls)->ech.aead != NULL)) == NULL) { + ret = PTLS_ERROR_NO_MEMORY; + goto Exit; + } + if ((ret = import_tls13_traffic_protection(*tls, 1, &src, end)) != 0) + goto Exit; + if ((ret = import_tls13_traffic_protection(*tls, 0, &src, end)) != 0) goto Exit; break; default: ret = PTLS_ALERT_ILLEGAL_PARAMETER; - break; + goto Exit; } }); /* extensions */ @@ -6232,7 +6294,6 @@ ptls_aead_context_t *new_aead(ptls_aead_algorithm_t *aead, ptls_hash_algorithm_t if ((ret = get_traffic_keys(aead, hash, key_iv.key, key_iv.iv, secret, hash_value, label_prefix)) != 0) goto Exit; ctx = ptls_aead_new_direct(aead, is_enc, key_iv.key, key_iv.iv); - Exit: ptls_clear_memory(&key_iv, sizeof(key_iv)); return ctx; diff --git a/t/picotls.c b/t/picotls.c index 01378332..1ce59258 100644 --- a/t/picotls.c +++ b/t/picotls.c @@ -864,7 +864,50 @@ static int can_ech(ptls_context_t *ctx, int is_server) } } -static void test_handshake(ptls_iovec_t ticket, int mode, int expect_ticket, int check_ch, int require_client_authentication) +static void check_clone(ptls_t *src, ptls_t *dest) +{ + ok(src->cipher_suite->hash->digest_size == dest->cipher_suite->hash->digest_size); + size_t digest_size = dest->cipher_suite->hash->digest_size; + ok(memcmp(src->traffic_protection.enc.secret, dest->traffic_protection.enc.secret, digest_size) == 0); + ok(memcmp(src->traffic_protection.dec.secret, dest->traffic_protection.dec.secret, digest_size) == 0); + const unsigned enc_idx = 0; + const unsigned dec_idx = 1; + struct { + uint8_t key[PTLS_MAX_SECRET_SIZE]; + uint8_t iv[PTLS_MAX_IV_SIZE]; + uint64_t seq; + } src_keys[2] = {0}, dest_keys[2] = {0}; + ok(ptls_get_traffic_keys(src, 1, src_keys[enc_idx].key, src_keys[enc_idx].iv, &src_keys[enc_idx].seq) == 0); + ok(ptls_get_traffic_keys(src, 0, src_keys[dec_idx].key, src_keys[dec_idx].iv, &src_keys[dec_idx].seq) == 0); + ok(ptls_get_traffic_keys(dest, 1, dest_keys[enc_idx].key, dest_keys[enc_idx].iv, &dest_keys[enc_idx].seq) == 0); + ok(ptls_get_traffic_keys(dest, 0, dest_keys[dec_idx].key, dest_keys[dec_idx].iv, &dest_keys[dec_idx].seq) == 0); + ok(src_keys[enc_idx].seq == dest_keys[enc_idx].seq); + ok(src_keys[dec_idx].seq == dest_keys[dec_idx].seq); + ok(memcmp(src_keys[enc_idx].key, dest_keys[enc_idx].key, PTLS_MAX_SECRET_SIZE) == 0); + ok(memcmp(src_keys[dec_idx].key, dest_keys[dec_idx].key, PTLS_MAX_SECRET_SIZE) == 0); + ok(memcmp(src_keys[enc_idx].iv, dest_keys[enc_idx].iv, PTLS_MAX_IV_SIZE) == 0); + ok(memcmp(src_keys[dec_idx].iv, dest_keys[dec_idx].iv, PTLS_MAX_IV_SIZE) == 0); +} + +static ptls_t *clone_tls(ptls_t *src) +{ + ptls_t *dest = NULL; + ptls_buffer_t sess_data; + + ptls_buffer_init(&sess_data, "", 0); + int r = ptls_export(src, &sess_data); + assert(r == 0); + r = ptls_import(ctx_peer, &dest, (ptls_iovec_t){.base = sess_data.base, .len = sess_data.off}); + assert(r == 0); + ptls_buffer_dispose(&sess_data); + + check_clone(src, dest); + + return dest; +} + +static void test_handshake(ptls_iovec_t ticket, int mode, int expect_ticket, int check_ch, int require_client_authentication, + int transfer_session) { ptls_t *client, *server; ptls_handshake_properties_t client_hs_prop = {{{{NULL}, ticket}}}, server_hs_prop = {{{{NULL}}}}; @@ -1055,6 +1098,9 @@ static void test_handshake(ptls_iovec_t ticket, int mode, int expect_ticket, int cbuf.off = 0; } + /* holds the ptls_t pointer of server prior to migration */ + ptls_t *original_server = server; + if (mode != TEST_HANDSHAKE_EARLY_DATA || require_client_authentication) { ret = ptls_send(client, &cbuf, req, strlen(req)); ok(ret == 0); @@ -1068,6 +1114,8 @@ static void test_handshake(ptls_iovec_t ticket, int mode, int expect_ticket, int ok(ptls_handshake_is_complete(server)); decbuf.off = 0; cbuf.off = 0; + if (transfer_session) + server = clone_tls(original_server); ret = ptls_send(server, &sbuf, resp, strlen(resp)); ok(ret == 0); @@ -1125,18 +1173,21 @@ static void test_handshake(ptls_iovec_t ticket, int mode, int expect_ticket, int decbuf.off = 0; } + /* original_server is used for the server-side checks because handshake data is never migrated */ if (can_ech(ctx_peer, 1) && can_ech(ctx, 0)) { ok(ptls_is_ech_handshake(client, NULL, NULL, NULL)); - ok(ptls_is_ech_handshake(server, NULL, NULL, NULL)); + ok(ptls_is_ech_handshake(original_server, NULL, NULL, NULL)); } else { ok(!ptls_is_ech_handshake(client, NULL, NULL, NULL)); - ok(!ptls_is_ech_handshake(server, NULL, NULL, NULL)); + ok(!ptls_is_ech_handshake(original_server, NULL, NULL, NULL)); } ptls_buffer_dispose(&cbuf); ptls_buffer_dispose(&sbuf); ptls_buffer_dispose(&decbuf); ptls_free(client); + if (original_server != server) + ptls_free(original_server); ptls_free(server); if (check_ch) @@ -1202,19 +1253,19 @@ static int second_sign_certificate(ptls_sign_certificate_t *self, ptls_t *tls, p return second_sc_orig->cb(second_sc_orig, tls, async, selected_algorithm, output, input, algorithms, num_algorithms); } -static void test_full_handshake_impl(int require_client_authentication, int is_async) +static void test_full_handshake_impl(int require_client_authentication, int is_async, int transfer_session) { - test_handshake(ptls_iovec_init(NULL, 0), TEST_HANDSHAKE_1RTT, 0, 0, require_client_authentication); + test_handshake(ptls_iovec_init(NULL, 0), TEST_HANDSHAKE_1RTT, 0, 0, require_client_authentication, transfer_session); ok(server_sc_callcnt == 1); ok(async_sc_callcnt == is_async); ok(client_sc_callcnt == require_client_authentication); - test_handshake(ptls_iovec_init(NULL, 0), TEST_HANDSHAKE_1RTT, 0, 0, require_client_authentication); + test_handshake(ptls_iovec_init(NULL, 0), TEST_HANDSHAKE_1RTT, 0, 0, require_client_authentication, transfer_session); ok(server_sc_callcnt == 1); ok(async_sc_callcnt == is_async); ok(client_sc_callcnt == require_client_authentication); - test_handshake(ptls_iovec_init(NULL, 0), TEST_HANDSHAKE_1RTT, 0, 1, require_client_authentication); + test_handshake(ptls_iovec_init(NULL, 0), TEST_HANDSHAKE_1RTT, 0, 1, require_client_authentication, transfer_session); ok(server_sc_callcnt == 1); ok(async_sc_callcnt == is_async); ok(client_sc_callcnt == require_client_authentication); @@ -1222,28 +1273,32 @@ static void test_full_handshake_impl(int require_client_authentication, int is_a static void test_full_handshake(void) { - test_full_handshake_impl(0, 0); + test_full_handshake_impl(0, 0, 0); + test_full_handshake_impl(0, 0, 0); } static void test_full_handshake_with_client_authentication(void) { - test_full_handshake_impl(1, 0); + test_full_handshake_impl(1, 0, 0); + test_full_handshake_impl(1, 0, 1); } static void test_key_update(void) { - test_handshake(ptls_iovec_init(NULL, 0), TEST_HANDSHAKE_KEY_UPDATE, 0, 0, 0); + test_handshake(ptls_iovec_init(NULL, 0), TEST_HANDSHAKE_KEY_UPDATE, 0, 0, 0, 0); + test_handshake(ptls_iovec_init(NULL, 0), TEST_HANDSHAKE_KEY_UPDATE, 0, 0, 0, 1); } static void test_hrr_handshake(void) { - test_handshake(ptls_iovec_init(NULL, 0), TEST_HANDSHAKE_HRR, 0, 0, 0); + test_handshake(ptls_iovec_init(NULL, 0), TEST_HANDSHAKE_HRR, 0, 0, 0, 0); ok(server_sc_callcnt == 1); + test_handshake(ptls_iovec_init(NULL, 0), TEST_HANDSHAKE_HRR, 0, 0, 0, 0); } static void test_hrr_stateless_handshake(void) { - test_handshake(ptls_iovec_init(NULL, 0), TEST_HANDSHAKE_HRR_STATELESS, 0, 0, 0); + test_handshake(ptls_iovec_init(NULL, 0), TEST_HANDSHAKE_HRR_STATELESS, 0, 0, 0, 0); ok(server_sc_callcnt == 1); } @@ -1269,7 +1324,7 @@ static int on_save_ticket(ptls_save_ticket_t *self, ptls_t *tls, ptls_iovec_t sr return 0; } -static void test_resumption_impl(int different_preferred_key_share, int require_client_authentication) +static void test_resumption_impl(int different_preferred_key_share, int require_client_authentication, int transfer_session) { assert(ctx->key_exchanges[0]->id == ctx_peer->key_exchanges[0]->id); assert(ctx->key_exchanges[1] == NULL); @@ -1295,29 +1350,30 @@ static void test_resumption_impl(int different_preferred_key_share, int require_ ctx_peer->encrypt_ticket = &et; ctx->save_ticket = &st; - test_handshake(saved_ticket, different_preferred_key_share ? TEST_HANDSHAKE_2RTT : TEST_HANDSHAKE_1RTT, 1, 0, 0); + test_handshake(saved_ticket, different_preferred_key_share ? TEST_HANDSHAKE_2RTT : TEST_HANDSHAKE_1RTT, 1, 0, 0, + transfer_session); ok(server_sc_callcnt == 1); ok(saved_ticket.base != NULL); /* psk using saved ticket */ - test_handshake(saved_ticket, TEST_HANDSHAKE_1RTT, 1, 0, require_client_authentication); + test_handshake(saved_ticket, TEST_HANDSHAKE_1RTT, 1, 0, require_client_authentication, transfer_session); ok(server_sc_callcnt == require_client_authentication); /* client authentication turns off resumption */ ok(client_sc_callcnt == require_client_authentication); /* 0-rtt psk using saved ticket */ - test_handshake(saved_ticket, TEST_HANDSHAKE_EARLY_DATA, 1, 0, require_client_authentication); + test_handshake(saved_ticket, TEST_HANDSHAKE_EARLY_DATA, 1, 0, require_client_authentication, transfer_session); ok(server_sc_callcnt == require_client_authentication); /* client authentication turns off resumption */ ok(client_sc_callcnt == require_client_authentication); ctx->require_dhe_on_psk = 1; /* psk-dhe using saved ticket */ - test_handshake(saved_ticket, TEST_HANDSHAKE_1RTT, 1, 0, require_client_authentication); + test_handshake(saved_ticket, TEST_HANDSHAKE_1RTT, 1, 0, require_client_authentication, transfer_session); ok(server_sc_callcnt == require_client_authentication); /* client authentication turns off resumption */ ok(client_sc_callcnt == require_client_authentication); /* 0-rtt psk-dhe using saved ticket */ - test_handshake(saved_ticket, TEST_HANDSHAKE_EARLY_DATA, 1, 0, require_client_authentication); + test_handshake(saved_ticket, TEST_HANDSHAKE_EARLY_DATA, 1, 0, require_client_authentication, transfer_session); ok(server_sc_callcnt == require_client_authentication); /* client authentication turns off resumption */ ok(client_sc_callcnt == require_client_authentication); @@ -1331,19 +1387,22 @@ static void test_resumption_impl(int different_preferred_key_share, int require_ static void test_resumption(void) { - test_resumption_impl(0, 0); + test_resumption_impl(0, 0, 0); + test_resumption_impl(0, 0, 1); } static void test_resumption_different_preferred_key_share(void) { if (ctx == ctx_peer) return; - test_resumption_impl(1, 0); + test_resumption_impl(1, 0, 0); + test_resumption_impl(0, 0, 1); } static void test_resumption_with_client_authentication(void) { - test_resumption_impl(0, 1); + test_resumption_impl(0, 0, 0); + test_resumption_impl(0, 1, 1); } static void test_async_sign_certificate(void) @@ -1353,7 +1412,7 @@ static void test_async_sign_certificate(void) ptls_sign_certificate_t async_sc = {async_sign_certificate}, *orig_sc = ctx_peer->sign_certificate; ctx_peer->sign_certificate = &async_sc; - test_full_handshake_impl(0, 1); + test_full_handshake_impl(0, 1, 0); ctx_peer->sign_certificate = orig_sc; }