[+] fix secret_check

This commit is contained in:
Siarhei Siniak 2025-03-04 18:38:52 +03:00
parent ff786e3ce6
commit 06e79d0679
2 changed files with 58 additions and 28 deletions

@ -1,12 +1,14 @@
import base64
import os
import cryptography.hazmat.primitives.kdf.scrypt
from typing import (Literal, overload, Optional,)
class PasswordUtils:
@overload
@classmethod
def encrypt(
def secret_hash(
cls,
secret: str,
mode: Literal['base64'],
@ -14,34 +16,28 @@ class PasswordUtils:
@overload
@classmethod
def encrypt(
def secret_hash(
cls,
secret: str,
mode: Literal['bytes'],
) -> tuple[bytes, bytes]: ...
@classmethod
def encrypt(
def secret_hash(
cls,
secret: str,
secret: str | bytes,
mode: Literal['bytes', 'base64'],
salt: Optional[bytes] = None,
) -> tuple[str, str] | tuple[bytes, bytes]:
from cryptography.hazmat.primitives.kdf.scrypt import Scrypt
if salt is None:
salt = os.urandom(16)
if isinstance(secret, str):
secret = secret.encode('utf-8')
# derive
kdf = Scrypt(
salt=salt,
length=32,
n=2**14,
r=8,
p=1,
)
kdf = cls._scrypt_init(salt=salt)
hashed_secret = kdf.derive(secret.encode('utf-8'))
hashed_secret = kdf.derive(secret)
if mode == 'bytes':
return (salt, hashed_secret)
@ -54,14 +50,39 @@ class PasswordUtils:
else:
raise NotImplementedError
# # verify
# kdf = Scrypt(
# salt=salt,
# length=32,
# n=2**14,
# r=8,
# p=1,
# )
#
# kdf.verify(b"my great password", key)
#
@classmethod
def _scrypt_init(
cls,
salt: bytes
) -> cryptography.hazmat.primitives.kdf.scrypt.Scrypt:
return cryptography.hazmat.primitives.kdf.scrypt.Scrypt(
salt=salt,
length=32,
n=2**14,
r=8,
p=1,
)
@classmethod
def secret_check(
cls,
secret: str,
salt: str | bytes,
hashed_secret: str | bytes,
) -> bool:
if isinstance(salt, str):
salt = base64.b64decode(salt)
if isinstance(secret, str):
secret = secret.encode('utf-8')
if isinstance(hashed_secret, str):
hashed_secret = base64.b64decode(hashed_secret)
kdf = cls._scrypt_init(salt=salt)
try:
kdf.verify(secret, hashed_secret)
return True
except cryptography.exceptions.InvalidKey:
return False

@ -6,15 +6,24 @@ class TestCrypto(unittest.TestCase):
def test_password_utils(self) -> None:
salt = b'asdfasdfasdf'
encrypt_res = crypto.PasswordUtils.encrypt(
'blah',
secret = 'blah'
hash_res = crypto.PasswordUtils.secret_hash(
secret,
mode='bytes',
salt=salt,
)
self.assertEqual(
encrypt_res,
hash_res,
(
salt,
b'\xdak\xd15\xfa\x8e\xc8\r\xc3\xd2c\xf1m\xb0\xbf\xe6\x98\x01$!j\xc8\xc0Hh\x84\xea,\x91\x8b\x08\xce',
),
)
check_res = crypto.PasswordUtils.secret_check(
secret,
*hash_res,
)
self.assertTrue(check_res)