[+] fix secret_check
This commit is contained in:
		
							parent
							
								
									ff786e3ce6
								
							
						
					
					
						commit
						06e79d0679
					
				@ -1,12 +1,14 @@
 | 
				
			|||||||
import base64
 | 
					import base64
 | 
				
			||||||
import os
 | 
					import os
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import cryptography.hazmat.primitives.kdf.scrypt
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from typing import (Literal, overload, Optional,)
 | 
					from typing import (Literal, overload, Optional,)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class PasswordUtils:
 | 
					class PasswordUtils:
 | 
				
			||||||
  @overload
 | 
					  @overload
 | 
				
			||||||
  @classmethod
 | 
					  @classmethod
 | 
				
			||||||
  def encrypt(
 | 
					  def secret_hash(
 | 
				
			||||||
    cls,
 | 
					    cls,
 | 
				
			||||||
    secret: str,
 | 
					    secret: str,
 | 
				
			||||||
    mode: Literal['base64'],
 | 
					    mode: Literal['base64'],
 | 
				
			||||||
@ -14,34 +16,28 @@ class PasswordUtils:
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
  @overload
 | 
					  @overload
 | 
				
			||||||
  @classmethod
 | 
					  @classmethod
 | 
				
			||||||
  def encrypt(
 | 
					  def secret_hash(
 | 
				
			||||||
    cls,
 | 
					    cls,
 | 
				
			||||||
    secret: str,
 | 
					    secret: str,
 | 
				
			||||||
    mode: Literal['bytes'],
 | 
					    mode: Literal['bytes'],
 | 
				
			||||||
  ) -> tuple[bytes, bytes]: ...
 | 
					  ) -> tuple[bytes, bytes]: ...
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  @classmethod
 | 
					  @classmethod
 | 
				
			||||||
  def encrypt(
 | 
					  def secret_hash(
 | 
				
			||||||
    cls,
 | 
					    cls,
 | 
				
			||||||
    secret: str,
 | 
					    secret: str | bytes,
 | 
				
			||||||
    mode: Literal['bytes', 'base64'],
 | 
					    mode: Literal['bytes', 'base64'],
 | 
				
			||||||
    salt: Optional[bytes] = None,
 | 
					    salt: Optional[bytes] = None,
 | 
				
			||||||
  ) -> tuple[str, str] | tuple[bytes, bytes]:
 | 
					  ) -> tuple[str, str] | tuple[bytes, bytes]:
 | 
				
			||||||
    from cryptography.hazmat.primitives.kdf.scrypt import Scrypt
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    if salt is None:
 | 
					    if salt is None:
 | 
				
			||||||
      salt = os.urandom(16)
 | 
					      salt = os.urandom(16)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if isinstance(secret, str):
 | 
				
			||||||
 | 
					      secret = secret.encode('utf-8')
 | 
				
			||||||
    # derive
 | 
					    # derive
 | 
				
			||||||
    kdf = Scrypt(
 | 
					    kdf = cls._scrypt_init(salt=salt)
 | 
				
			||||||
      salt=salt,
 | 
					 | 
				
			||||||
      length=32,
 | 
					 | 
				
			||||||
      n=2**14,
 | 
					 | 
				
			||||||
      r=8,
 | 
					 | 
				
			||||||
      p=1,
 | 
					 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    hashed_secret = kdf.derive(secret.encode('utf-8'))
 | 
					    hashed_secret = kdf.derive(secret)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if mode == 'bytes':
 | 
					    if mode == 'bytes':
 | 
				
			||||||
      return (salt, hashed_secret)
 | 
					      return (salt, hashed_secret)
 | 
				
			||||||
@ -54,14 +50,39 @@ class PasswordUtils:
 | 
				
			|||||||
    else:
 | 
					    else:
 | 
				
			||||||
      raise NotImplementedError
 | 
					      raise NotImplementedError
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # # verify
 | 
					  @classmethod
 | 
				
			||||||
    # kdf = Scrypt(
 | 
					  def _scrypt_init(
 | 
				
			||||||
    #   salt=salt,
 | 
					    cls,
 | 
				
			||||||
    #   length=32,
 | 
					    salt: bytes
 | 
				
			||||||
    #   n=2**14,
 | 
					  ) -> cryptography.hazmat.primitives.kdf.scrypt.Scrypt:
 | 
				
			||||||
    #   r=8,
 | 
					    return cryptography.hazmat.primitives.kdf.scrypt.Scrypt(
 | 
				
			||||||
    #   p=1,
 | 
					      salt=salt,
 | 
				
			||||||
    # )
 | 
					      length=32,
 | 
				
			||||||
    #
 | 
					      n=2**14,
 | 
				
			||||||
    # kdf.verify(b"my great password", key)
 | 
					      r=8,
 | 
				
			||||||
    #
 | 
					      p=1,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  @classmethod
 | 
				
			||||||
 | 
					  def secret_check(
 | 
				
			||||||
 | 
					    cls,
 | 
				
			||||||
 | 
					    secret: str,
 | 
				
			||||||
 | 
					    salt: str | bytes,
 | 
				
			||||||
 | 
					    hashed_secret: str | bytes,
 | 
				
			||||||
 | 
					  ) -> bool:
 | 
				
			||||||
 | 
					    if isinstance(salt, str):
 | 
				
			||||||
 | 
					      salt = base64.b64decode(salt)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if isinstance(secret, str):
 | 
				
			||||||
 | 
					      secret = secret.encode('utf-8')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if isinstance(hashed_secret, str):
 | 
				
			||||||
 | 
					      hashed_secret = base64.b64decode(hashed_secret)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    kdf = cls._scrypt_init(salt=salt)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    try:
 | 
				
			||||||
 | 
					      kdf.verify(secret, hashed_secret)
 | 
				
			||||||
 | 
					      return True
 | 
				
			||||||
 | 
					    except cryptography.exceptions.InvalidKey:
 | 
				
			||||||
 | 
					      return False
 | 
				
			||||||
 | 
				
			|||||||
@ -6,15 +6,24 @@ class TestCrypto(unittest.TestCase):
 | 
				
			|||||||
    def test_password_utils(self) -> None:
 | 
					    def test_password_utils(self) -> None:
 | 
				
			||||||
        salt = b'asdfasdfasdf'
 | 
					        salt = b'asdfasdfasdf'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        encrypt_res = crypto.PasswordUtils.encrypt(
 | 
					        secret = 'blah'
 | 
				
			||||||
            'blah',
 | 
					
 | 
				
			||||||
 | 
					        hash_res = crypto.PasswordUtils.secret_hash(
 | 
				
			||||||
 | 
					            secret,
 | 
				
			||||||
            mode='bytes',
 | 
					            mode='bytes',
 | 
				
			||||||
            salt=salt,
 | 
					            salt=salt,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        self.assertEqual(
 | 
					        self.assertEqual(
 | 
				
			||||||
            encrypt_res,
 | 
					            hash_res,
 | 
				
			||||||
            (
 | 
					            (
 | 
				
			||||||
                salt,
 | 
					                salt,
 | 
				
			||||||
                b'\xdak\xd15\xfa\x8e\xc8\r\xc3\xd2c\xf1m\xb0\xbf\xe6\x98\x01$!j\xc8\xc0Hh\x84\xea,\x91\x8b\x08\xce',
 | 
					                b'\xdak\xd15\xfa\x8e\xc8\r\xc3\xd2c\xf1m\xb0\xbf\xe6\x98\x01$!j\xc8\xc0Hh\x84\xea,\x91\x8b\x08\xce',
 | 
				
			||||||
            ),
 | 
					            ),
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        check_res = crypto.PasswordUtils.secret_check(
 | 
				
			||||||
 | 
					            secret,
 | 
				
			||||||
 | 
					            *hash_res,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.assertTrue(check_res)
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
		Reference in New Issue
	
	Block a user