diff options
Diffstat (limited to 'crypto/sm2.c')
| -rw-r--r-- | crypto/sm2.c | 106 | 
1 files changed, 70 insertions, 36 deletions
diff --git a/crypto/sm2.c b/crypto/sm2.c index ed9307dac3d1..285b3cb7c0bc 100644 --- a/crypto/sm2.c +++ b/crypto/sm2.c @@ -13,11 +13,14 @@  #include <crypto/internal/akcipher.h>  #include <crypto/akcipher.h>  #include <crypto/hash.h> -#include <crypto/sm3.h>  #include <crypto/rng.h>  #include <crypto/sm2.h>  #include "sm2signature.asn1.h" +/* The default user id as specified in GM/T 0009-2012 */ +#define SM2_DEFAULT_USERID "1234567812345678" +#define SM2_DEFAULT_USERID_LEN 16 +  #define MPI_NBYTES(m)   ((mpi_get_nbits(m) + 7) / 8)  struct ecc_domain_parms { @@ -60,6 +63,9 @@ static const struct ecc_domain_parms sm2_ecp = {  	.h = 1  }; +static int __sm2_set_pub_key(struct mpi_ec_ctx *ec, +			     const void *key, unsigned int keylen); +  static int sm2_ec_ctx_init(struct mpi_ec_ctx *ec)  {  	const struct ecc_domain_parms *ecp = &sm2_ecp; @@ -213,12 +219,13 @@ int sm2_get_signature_s(void *context, size_t hdrlen, unsigned char tag,  	return 0;  } -static int sm2_z_digest_update(struct sm3_state *sctx, -			MPI m, unsigned int pbytes) +static int sm2_z_digest_update(struct shash_desc *desc, +			       MPI m, unsigned int pbytes)  {  	static const unsigned char zero[32];  	unsigned char *in;  	unsigned int inlen; +	int err;  	in = mpi_get_buffer(m, &inlen, NULL);  	if (!in) @@ -226,21 +233,22 @@ static int sm2_z_digest_update(struct sm3_state *sctx,  	if (inlen < pbytes) {  		/* padding with zero */ -		sm3_update(sctx, zero, pbytes - inlen); -		sm3_update(sctx, in, inlen); +		err = crypto_shash_update(desc, zero, pbytes - inlen) ?: +		      crypto_shash_update(desc, in, inlen);  	} else if (inlen > pbytes) {  		/* skip the starting zero */ -		sm3_update(sctx, in + inlen - pbytes, pbytes); +		err = crypto_shash_update(desc, in + inlen - pbytes, pbytes);  	} else { -		sm3_update(sctx, in, inlen); +		err = crypto_shash_update(desc, in, inlen);  	}  	kfree(in); -	return 0; +	return err;  } -static int sm2_z_digest_update_point(struct sm3_state *sctx, -		MPI_POINT point, struct mpi_ec_ctx *ec, unsigned int pbytes) +static int sm2_z_digest_update_point(struct shash_desc *desc, +				     MPI_POINT point, struct mpi_ec_ctx *ec, +				     unsigned int pbytes)  {  	MPI x, y;  	int ret = -EINVAL; @@ -248,50 +256,68 @@ static int sm2_z_digest_update_point(struct sm3_state *sctx,  	x = mpi_new(0);  	y = mpi_new(0); -	if (!mpi_ec_get_affine(x, y, point, ec) && -	    !sm2_z_digest_update(sctx, x, pbytes) && -	    !sm2_z_digest_update(sctx, y, pbytes)) -		ret = 0; +	ret = mpi_ec_get_affine(x, y, point, ec) ? -EINVAL : +	      sm2_z_digest_update(desc, x, pbytes) ?: +	      sm2_z_digest_update(desc, y, pbytes);  	mpi_free(x);  	mpi_free(y);  	return ret;  } -int sm2_compute_z_digest(struct crypto_akcipher *tfm, -			const unsigned char *id, size_t id_len, -			unsigned char dgst[SM3_DIGEST_SIZE]) +int sm2_compute_z_digest(struct shash_desc *desc, +			 const void *key, unsigned int keylen, void *dgst)  { -	struct mpi_ec_ctx *ec = akcipher_tfm_ctx(tfm); -	uint16_t bits_len; -	unsigned char entl[2]; -	struct sm3_state sctx; +	struct mpi_ec_ctx *ec; +	unsigned int bits_len;  	unsigned int pbytes; +	u8 entl[2]; +	int err; -	if (id_len > (USHRT_MAX / 8) || !ec->Q) -		return -EINVAL; +	ec = kmalloc(sizeof(*ec), GFP_KERNEL); +	if (!ec) +		return -ENOMEM; + +	err = __sm2_set_pub_key(ec, key, keylen); +	if (err) +		goto out_free_ec; -	bits_len = (uint16_t)(id_len * 8); +	bits_len = SM2_DEFAULT_USERID_LEN * 8;  	entl[0] = bits_len >> 8;  	entl[1] = bits_len & 0xff;  	pbytes = MPI_NBYTES(ec->p);  	/* ZA = H256(ENTLA | IDA | a | b | xG | yG | xA | yA) */ -	sm3_init(&sctx); -	sm3_update(&sctx, entl, 2); -	sm3_update(&sctx, id, id_len); - -	if (sm2_z_digest_update(&sctx, ec->a, pbytes) || -	    sm2_z_digest_update(&sctx, ec->b, pbytes) || -	    sm2_z_digest_update_point(&sctx, ec->G, ec, pbytes) || -	    sm2_z_digest_update_point(&sctx, ec->Q, ec, pbytes)) -		return -EINVAL; +	err = crypto_shash_init(desc); +	if (err) +		goto out_deinit_ec; -	sm3_final(&sctx, dgst); -	return 0; +	err = crypto_shash_update(desc, entl, 2); +	if (err) +		goto out_deinit_ec; + +	err = crypto_shash_update(desc, SM2_DEFAULT_USERID, +				  SM2_DEFAULT_USERID_LEN); +	if (err) +		goto out_deinit_ec; + +	err = sm2_z_digest_update(desc, ec->a, pbytes) ?: +	      sm2_z_digest_update(desc, ec->b, pbytes) ?: +	      sm2_z_digest_update_point(desc, ec->G, ec, pbytes) ?: +	      sm2_z_digest_update_point(desc, ec->Q, ec, pbytes); +	if (err) +		goto out_deinit_ec; + +	err = crypto_shash_final(desc, dgst); + +out_deinit_ec: +	sm2_ec_ctx_deinit(ec); +out_free_ec: +	kfree(ec); +	return err;  } -EXPORT_SYMBOL(sm2_compute_z_digest); +EXPORT_SYMBOL_GPL(sm2_compute_z_digest);  static int _sm2_verify(struct mpi_ec_ctx *ec, MPI hash, MPI sig_r, MPI sig_s)  { @@ -391,6 +417,14 @@ static int sm2_set_pub_key(struct crypto_akcipher *tfm,  			const void *key, unsigned int keylen)  {  	struct mpi_ec_ctx *ec = akcipher_tfm_ctx(tfm); + +	return __sm2_set_pub_key(ec, key, keylen); + +} + +static int __sm2_set_pub_key(struct mpi_ec_ctx *ec, +			     const void *key, unsigned int keylen) +{  	MPI a;  	int rc;  |