components/openssl: optimize the SSL certification and private key function

1. add inheritance function
2. remove low-level platform unload cert & pkey function
3. optimize the cert load and free function
This commit is contained in:
Dong Heng 2016-09-26 11:14:19 +08:00
parent e1c4a4bfa3
commit cf4aaf6397
12 changed files with 178 additions and 164 deletions

View File

@ -21,6 +21,15 @@
#include "ssl_types.h"
/**
* @brief create a certification object include private key object according to input certification
*
* @param ic - input certification point
*
* @return certification object point
*/
CERT *__ssl_cert_new(CERT *ic);
/**
* @brief create a certification object include private key object
*

View File

@ -69,14 +69,12 @@
#define IMPLEMENT_X509_METHOD(func_name, \
new, \
free, \
load, \
unload) \
load) \
const X509_METHOD* func_name(void) { \
static const X509_METHOD func_name##_data LOCAL_ATRR = { \
new, \
free, \
load, \
unload, \
load \
}; \
return &func_name##_data; \
}
@ -84,14 +82,12 @@
#define IMPLEMENT_PKEY_METHOD(func_name, \
new, \
free, \
load, \
unload) \
load) \
const PKEY_METHOD* func_name(void) { \
static const PKEY_METHOD func_name##_data LOCAL_ATRR = { \
new, \
free, \
load, \
unload, \
load \
}; \
return &func_name##_data; \
}

View File

@ -21,6 +21,15 @@
#include "ssl_types.h"
/**
* @brief create a private key object according to input private key
*
* @param ipk - input private key point
*
* @return new private key object point
*/
EVP_PKEY* __EVP_PKEY_new(EVP_PKEY *ipk);
/**
* @brief create a private key object
*

View File

@ -196,12 +196,8 @@ struct ssl_st
/* shut things down(0x01 : sent, 0x02 : received) */
int shutdown;
int crt_reload;
CERT *cert;
int ca_reload;
X509 *client_CA;
SSL_CTX *ctx;
@ -274,24 +270,20 @@ struct ssl_method_func_st {
struct x509_method_st {
int (*x509_new)(X509 *x);
int (*x509_new)(X509 *x, X509 *m_x);
void (*x509_free)(X509 *x);
int (*x509_load)(X509 *x, const unsigned char *buf, int len);
void (*x509_unload)(X509 *x);
};
struct pkey_method_st {
int (*pkey_new)(EVP_PKEY *pkey);
int (*pkey_new)(EVP_PKEY *pkey, EVP_PKEY *m_pkey);
void (*pkey_free)(EVP_PKEY *pkey);
int (*pkey_load)(EVP_PKEY *pkey, const unsigned char *buf, int len);
void (*pkey_unload)(EVP_PKEY *pkey);
};
typedef int (*next_proto_cb)(SSL *ssl, unsigned char **out,

View File

@ -24,6 +24,15 @@
DEFINE_STACK_OF(X509_NAME)
/**
* @brief create a X509 certification object according to input X509 certification
*
* @param ix - input X509 certification point
*
* @return new X509 certification object point
*/
X509* __X509_new(X509 *ix);
/**
* @brief create a X509 certification object
*

View File

@ -42,16 +42,13 @@ OSSL_HANDSHAKE_STATE ssl_pm_get_state(const SSL *ssl);
void ssl_pm_set_bufflen(SSL *ssl, int len);
int x509_pm_new(X509 *x);
int x509_pm_new(X509 *x, X509 *m_x);
void x509_pm_free(X509 *x);
int x509_pm_load(X509 *x, const unsigned char *buffer, int len);
void x509_pm_unload(X509 *x);
void x509_pm_start_ca(X509 *x);
int pkey_pm_new(EVP_PKEY *pkey);
void pkey_pm_free(EVP_PKEY *pkey);
int pkey_pm_load(EVP_PKEY *pkey, const unsigned char *buffer, int len);
void pkey_pm_unload(EVP_PKEY *pkey);
int pkey_pm_new(EVP_PKEY *pk, EVP_PKEY *m_pk);
void pkey_pm_free(EVP_PKEY *pk);
int pkey_pm_load(EVP_PKEY *pk, const unsigned char *buffer, int len);
long ssl_pm_get_verify_result(const SSL *ssl);

View File

@ -19,23 +19,34 @@
#include "ssl_port.h"
/**
* @brief create a certification object include private key object
* @brief create a certification object according to input certification
*/
CERT *ssl_cert_new(void)
CERT *__ssl_cert_new(CERT *ic)
{
CERT *cert;
X509 *ix;
EVP_PKEY *ipk;
cert = ssl_zalloc(sizeof(CERT));
if (!cert)
SSL_RET(failed1, "ssl_zalloc\n");
cert->pkey = EVP_PKEY_new();
if (!cert->pkey)
SSL_RET(failed2, "EVP_PKEY_new\n");
if (ic) {
ipk = ic->pkey;
ix = ic->x509;
} else {
ipk = NULL;
ix = NULL;
}
cert->x509 = X509_new();
cert->pkey = __EVP_PKEY_new(ipk);
if (!cert->pkey)
SSL_RET(failed2, "__EVP_PKEY_new\n");
cert->x509 = __X509_new(ix);
if (!cert->x509)
SSL_RET(failed3, "X509_new\n");
SSL_RET(failed3, "__X509_new\n");
return cert;
@ -47,6 +58,14 @@ failed1:
return NULL;
}
/**
* @brief create a certification object include private key object
*/
CERT *ssl_cert_new(void)
{
return __ssl_cert_new(NULL);
}
/**
* @brief free a certification object
*/

View File

@ -158,11 +158,11 @@ SSL_CTX* SSL_CTX_new(const SSL_METHOD *method)
CERT *cert;
X509 *client_ca;
if (!method) SSL_RET(go_failed1, "method\n");
if (!method) SSL_RET(go_failed1, "method:NULL\n");
client_ca = X509_new();
if (!client_ca)
SSL_RET(go_failed1, "sk_X509_NAME_new_null\n");
SSL_RET(go_failed1, "X509_new\n");
cert = ssl_cert_new();
if (!cert)
@ -170,7 +170,7 @@ SSL_CTX* SSL_CTX_new(const SSL_METHOD *method)
ctx = (SSL_CTX *)ssl_zalloc(sizeof(SSL_CTX));
if (!ctx)
SSL_RET(go_failed3, "ssl_ctx_new:ctx\n");
SSL_RET(go_failed3, "ssl_zalloc:ctx\n");
ctx->method = method;
ctx->client_CA = client_ca;
@ -244,15 +244,15 @@ SSL *SSL_new(SSL_CTX *ctx)
ssl->session = SSL_SESSION_new();
if (!ssl->session)
SSL_RET(failed2, "ssl_zalloc\n");
SSL_RET(failed2, "SSL_SESSION_new\n");
ssl->cert = ssl_cert_new();
ssl->cert = __ssl_cert_new(ctx->cert);
if (!ssl->cert)
SSL_RET(failed3, "ssl_cert_new\n");
SSL_RET(failed3, "__ssl_cert_new\n");
ssl->client_CA = X509_new();
ssl->client_CA = __X509_new(ctx->client_CA);
if (!ssl->client_CA)
SSL_RET(failed4, "ssl_cert_new\n");
SSL_RET(failed4, "__X509_new\n");
ssl->ctx = ctx;
ssl->method = ctx->method;

View File

@ -72,11 +72,11 @@ IMPLEMENT_SSL_METHOD(SSL3_VERSION, -1, TLS_method_func, SSLv3_method);
*/
IMPLEMENT_X509_METHOD(X509_method,
x509_pm_new, x509_pm_free,
x509_pm_load, x509_pm_unload);
x509_pm_load);
/**
* @brief get private key object method
*/
IMPLEMENT_PKEY_METHOD(EVP_PKEY_method,
pkey_pm_new, pkey_pm_free,
pkey_pm_load, pkey_pm_unload);
pkey_pm_load);

View File

@ -20,20 +20,24 @@
#include "ssl_port.h"
/**
* @brief create a private key object
* @brief create a private key object according to input private key
*/
EVP_PKEY* EVP_PKEY_new(void)
EVP_PKEY* __EVP_PKEY_new(EVP_PKEY *ipk)
{
int ret;
EVP_PKEY *pkey;
pkey = ssl_zalloc(sizeof(EVP_PKEY));
if (!pkey)
SSL_RET(failed1, "ssl_malloc\n");
SSL_RET(failed1, "ssl_zalloc\n");
pkey->method = EVP_PKEY_method();
if (ipk) {
pkey->method = ipk->method;
} else {
pkey->method = EVP_PKEY_method();
}
ret = EVP_PKEY_METHOD_CALL(new, pkey);
ret = EVP_PKEY_METHOD_CALL(new, pkey, ipk);
if (ret)
SSL_RET(failed2, "EVP_PKEY_METHOD_CALL\n");
@ -45,6 +49,14 @@ failed1:
return NULL;
}
/**
* @brief create a private key object
*/
EVP_PKEY* EVP_PKEY_new(void)
{
return __EVP_PKEY_new(NULL);
}
/**
* @brief free a private key object
*/
@ -105,6 +117,9 @@ int SSL_CTX_use_PrivateKey(SSL_CTX *ctx, EVP_PKEY *pkey)
SSL_ASSERT(ctx);
SSL_ASSERT(pkey);
if (ctx->cert->pkey == pkey)
return 1;
if (ctx->cert->pkey)
EVP_PKEY_free(ctx->cert->pkey);
@ -118,12 +133,13 @@ int SSL_CTX_use_PrivateKey(SSL_CTX *ctx, EVP_PKEY *pkey)
*/
int SSL_use_PrivateKey(SSL *ssl, EVP_PKEY *pkey)
{
SSL_ASSERT(ctx);
SSL_ASSERT(ssl);
SSL_ASSERT(pkey);
if (!ssl->ca_reload)
ssl->ca_reload = 1;
else
if (ssl->cert->pkey == pkey)
return 1;
if (ssl->cert->pkey)
EVP_PKEY_free(ssl->cert->pkey);
ssl->cert->pkey = pkey;
@ -138,20 +154,20 @@ int SSL_CTX_use_PrivateKey_ASN1(int type, SSL_CTX *ctx,
const unsigned char *d, long len)
{
int ret;
EVP_PKEY *pkey;
EVP_PKEY *pk;
pkey = d2i_PrivateKey(0, &ctx->cert->pkey, &d, len);
if (!pkey)
pk = d2i_PrivateKey(0, NULL, &d, len);
if (!pk)
SSL_RET(failed1, "d2i_PrivateKey\n");
ret = SSL_CTX_use_PrivateKey(ctx, pkey);
ret = SSL_CTX_use_PrivateKey(ctx, pk);
if (!ret)
SSL_RET(failed2, "SSL_CTX_use_PrivateKey\n");
return 1;
failed2:
EVP_PKEY_free(pkey);
EVP_PKEY_free(pk);
failed1:
return 0;
}
@ -163,44 +179,20 @@ int SSL_use_PrivateKey_ASN1(int type, SSL *ssl,
const unsigned char *d, long len)
{
int ret;
int reload;
EVP_PKEY *pkey;
CERT *cert;
CERT *old_cert;
EVP_PKEY *pk;
if (!ssl->crt_reload) {
cert = ssl_cert_new();
if (!cert)
SSL_RET(failed1, "ssl_cert_new\n");
pk = d2i_PrivateKey(0, NULL, &d, len);
if (!pk)
SSL_RET(failed1, "d2i_PrivateKey\n");
old_cert = ssl->cert ;
ssl->cert = cert;
ssl->crt_reload = 1;
reload = 1;
} else {
reload = 0;
}
pkey = d2i_PrivateKey(0, &ssl->cert->pkey, &d, len);
if (!pkey)
SSL_RET(failed2, "d2i_PrivateKey\n");
ret = SSL_use_PrivateKey(ssl, pkey);
ret = SSL_use_PrivateKey(ssl, pk);
if (!ret)
SSL_RET(failed3, "SSL_use_PrivateKey\n");
SSL_RET(failed2, "SSL_use_PrivateKey\n");
return 1;
failed3:
EVP_PKEY_free(pkey);
failed2:
if (reload) {
ssl->cert = old_cert;
ssl_cert_free(cert);
ssl->crt_reload = 0;
}
EVP_PKEY_free(pk);
failed1:
return 0;
}

View File

@ -19,9 +19,9 @@
#include "ssl_port.h"
/**
* @brief create a X509 certification object
* @brief create a X509 certification object according to input X509 certification
*/
X509* X509_new(void)
X509* __X509_new(X509 *ix)
{
int ret;
X509 *x;
@ -30,9 +30,12 @@ X509* X509_new(void)
if (!x)
SSL_RET(failed1, "ssl_malloc\n");
x->method = X509_method();
if (ix)
x->method = ix->method;
else
x->method = X509_method();
ret = X509_METHOD_CALL(new, x);
ret = X509_METHOD_CALL(new, x, ix);
if (ret)
SSL_RET(failed2, "x509_new\n");
@ -44,6 +47,14 @@ failed1:
return NULL;
}
/**
* @brief create a X509 certification object
*/
X509* X509_new(void)
{
return __X509_new(NULL);
}
/**
* @brief free a X509 certification object
*/
@ -78,7 +89,7 @@ X509* d2i_X509(X509 **cert, const unsigned char *buffer, long len)
ret = X509_METHOD_CALL(load, x, buffer, len);
if (ret)
SSL_RET(failed2, "X509_METHOD_CALL\n");
SSL_RET(failed2, "x509_load\n");
return x;
@ -97,8 +108,10 @@ int SSL_CTX_add_client_CA(SSL_CTX *ctx, X509 *x)
SSL_ASSERT(ctx);
SSL_ASSERT(x);
if (ctx->client_CA)
X509_free(ctx->client_CA);
if (ctx->client_CA == x)
return 1;
X509_free(ctx->client_CA);
ctx->client_CA = x;
@ -113,10 +126,10 @@ int SSL_add_client_CA(SSL *ssl, X509 *x)
SSL_ASSERT(ssl);
SSL_ASSERT(x);
if (!ssl->ca_reload)
ssl->ca_reload = 1;
else
X509_free(ssl->client_CA);
if (ssl->client_CA == x)
return 1;
X509_free(ssl->client_CA);
ssl->client_CA = x;
@ -131,6 +144,11 @@ int SSL_CTX_use_certificate(SSL_CTX *ctx, X509 *x)
SSL_ASSERT(ctx);
SSL_ASSERT(x);
if (ctx->cert->x509 == x)
return 1;
X509_free(ctx->cert->x509);
ctx->cert->x509 = x;
return 1;
@ -141,9 +159,14 @@ int SSL_CTX_use_certificate(SSL_CTX *ctx, X509 *x)
*/
int SSL_use_certificate(SSL *ssl, X509 *x)
{
SSL_ASSERT(ctx);
SSL_ASSERT(ssl);
SSL_ASSERT(x);
if (ssl->cert->x509 == x)
return 1;
X509_free(ssl->cert->x509);
ssl->cert->x509 = x;
return 1;
@ -166,20 +189,20 @@ int SSL_CTX_use_certificate_ASN1(SSL_CTX *ctx, int len,
const unsigned char *d)
{
int ret;
X509 *cert;
X509 *x;
cert = d2i_X509(&ctx->cert->x509, d, len);
if (!cert)
x = d2i_X509(NULL, d, len);
if (!x)
SSL_RET(failed1, "d2i_X509\n");
ret = SSL_CTX_use_certificate(ctx, cert);
ret = SSL_CTX_use_certificate(ctx, x);
if (!ret)
SSL_RET(failed2, "SSL_CTX_use_certificate\n");
return 1;
failed2:
X509_free(cert);
X509_free(x);
failed1:
return 0;
}
@ -193,42 +216,20 @@ int SSL_use_certificate_ASN1(SSL *ssl, int len,
int ret;
int reload;
X509 *x;
CERT *cert;
CERT *old_cert;
int m = 0;
if (!ssl->crt_reload) {
cert = ssl_cert_new();
if (!cert)
SSL_RET(failed1, "ssl_cert_new\n");
old_cert = ssl->cert ;
ssl->cert = cert;
ssl->crt_reload = 1;
reload = 1;
} else {
reload = 0;
}
x = d2i_X509(&ssl->cert->x509, d, len);
x = d2i_X509(NULL, d, len);
if (!x)
SSL_RET(failed2, "d2i_X509\n");
SSL_RET(failed1, "d2i_X509\n");
ret = SSL_use_certificate(ssl, x);
if (!ret)
SSL_RET(failed3, "SSL_use_certificate\n");
SSL_RET(failed2, "SSL_use_certificate\n");
return 1;
failed3:
X509_free(x);
failed2:
if (reload) {
ssl->cert = old_cert;
ssl_cert_free(cert);
ssl->crt_reload = 0;
}
X509_free(x);
failed1:
return 0;
}

View File

@ -78,14 +78,6 @@ int ssl_pm_new(SSL *ssl)
const SSL_METHOD *method = ssl->method;
struct x509_pm *ctx_ca = (struct x509_pm *)ssl->ctx->client_CA->x509_pm;
struct x509_pm *ctx_crt = (struct x509_pm *)ssl->ctx->cert->x509->x509_pm;
struct pkey_pm *ctx_pkey = (struct pkey_pm *)ssl->ctx->cert->pkey->pkey_pm;
struct x509_pm *ssl_ca = (struct x509_pm *)ssl->client_CA->x509_pm;
struct x509_pm *ssl_crt = (struct x509_pm *)ssl->cert->x509->x509_pm;
struct pkey_pm *ssl_pkey = (struct pkey_pm *)ssl->cert->pkey->pkey_pm;
ssl_pm = ssl_zalloc(sizeof(struct ssl_pm));
if (!ssl_pm)
SSL_ERR(ret, failed1, "ssl_zalloc\n");
@ -134,10 +126,6 @@ int ssl_pm_new(SSL *ssl)
ssl->ssl_pm = ssl_pm;
ssl_ca->ex_crt = ctx_ca->x509_crt;
ssl_crt->ex_crt = ctx_crt->x509_crt;
ssl_pkey->ex_pkey = ctx_pkey->pkey;
return 0;
failed3:
@ -376,7 +364,7 @@ OSSL_HANDSHAKE_STATE ssl_pm_get_state(const SSL *ssl)
return state;
}
int x509_pm_new(X509 *x)
int x509_pm_new(X509 *x, X509 *m_x)
{
struct x509_pm *x509_pm;
@ -386,13 +374,19 @@ int x509_pm_new(X509 *x)
x->x509_pm = x509_pm;
if (m_x) {
struct x509_pm *m_x509_pm = (struct x509_pm *)m_x->x509_pm;
x509_pm->ex_crt = m_x509_pm->x509_crt;
}
return 0;
failed1:
return -1;
}
void x509_pm_unload(X509 *x)
void x509_pm_free(X509 *x)
{
struct x509_pm *x509_pm = (struct x509_pm *)x->x509_pm;
@ -402,11 +396,6 @@ void x509_pm_unload(X509 *x)
ssl_free(x509_pm->x509_crt);
x509_pm->x509_crt = NULL;
}
}
void x509_pm_free(X509 *x)
{
x509_pm_unload(x);
ssl_free(x->x509_pm);
x->x509_pm = NULL;
@ -450,7 +439,7 @@ failed1:
return -1;
}
int pkey_pm_new(EVP_PKEY *pkey)
int pkey_pm_new(EVP_PKEY *pk, EVP_PKEY *m_pkey)
{
struct pkey_pm *pkey_pm;
@ -458,14 +447,20 @@ int pkey_pm_new(EVP_PKEY *pkey)
if (!pkey_pm)
return -1;
pkey->pkey_pm = pkey_pm;
pk->pkey_pm = pkey_pm;
if (m_pkey) {
struct pkey_pm *m_pkey_pm = (struct pkey_pm *)m_pkey->pkey_pm;
pkey_pm->ex_pkey = m_pkey_pm->pkey;
}
return 0;
}
void pkey_pm_unload(EVP_PKEY *pkey)
void pkey_pm_free(EVP_PKEY *pk)
{
struct pkey_pm *pkey_pm = (struct pkey_pm *)pkey->pkey_pm;
struct pkey_pm *pkey_pm = (struct pkey_pm *)pk->pkey_pm;
if (pkey_pm->pkey) {
mbedtls_pk_free(pkey_pm->pkey);
@ -473,21 +468,16 @@ void pkey_pm_unload(EVP_PKEY *pkey)
ssl_free(pkey_pm->pkey);
pkey_pm->pkey = NULL;
}
ssl_free(pk->pkey_pm);
pk->pkey_pm = NULL;
}
void pkey_pm_free(EVP_PKEY *pkey)
{
pkey_pm_unload(pkey);
ssl_free(pkey->pkey_pm);
pkey->pkey_pm = NULL;
}
int pkey_pm_load(EVP_PKEY *pkey, const unsigned char *buffer, int len)
int pkey_pm_load(EVP_PKEY *pk, const unsigned char *buffer, int len)
{
int ret;
unsigned char *load_buf;
struct pkey_pm *pkey_pm = (struct pkey_pm *)pkey->pkey_pm;
struct pkey_pm *pkey_pm = (struct pkey_pm *)pk->pkey_pm;
if (!pkey_pm->pkey) {
pkey_pm->pkey = ssl_malloc(sizeof(mbedtls_pk_context));