5
5
import sys
6
6
import typing
7
7
from types import TracebackType
8
- from urllib .parse import SplitResult , parse_qsl , urlsplit
8
+ from urllib .parse import SplitResult , parse_qsl , urlsplit , unquote
9
9
10
10
from sqlalchemy import text
11
11
from sqlalchemy .sql import ClauseElement
@@ -389,7 +389,14 @@ def __bool__(self) -> bool:
389
389
390
390
class DatabaseURL :
391
391
def __init__ (self , url : typing .Union [str , "DatabaseURL" ]):
392
- self ._url = str (url )
392
+ if isinstance (url , DatabaseURL ):
393
+ self ._url : str = url ._url
394
+ elif isinstance (url , str ):
395
+ self ._url = url
396
+ else :
397
+ raise TypeError (
398
+ f"Invalid type for DatabaseURL. Expected str or DatabaseURL, got { type (url )} "
399
+ )
393
400
394
401
@property
395
402
def components (self ) -> SplitResult :
@@ -411,13 +418,26 @@ def driver(self) -> str:
411
418
return ""
412
419
return self .components .scheme .split ("+" , 1 )[1 ]
413
420
421
+ @property
422
+ def userinfo (self ) -> typing .Optional [bytes ]:
423
+ if self .components .username :
424
+ info = self .components .username
425
+ if self .components .password :
426
+ info += ":" + self .components .password
427
+ return info .encode ("utf-8" )
428
+ return None
429
+
414
430
@property
415
431
def username (self ) -> typing .Optional [str ]:
416
- return self .components .username
432
+ if self .components .username is None :
433
+ return None
434
+ return unquote (self .components .username )
417
435
418
436
@property
419
437
def password (self ) -> typing .Optional [str ]:
420
- return self .components .password
438
+ if self .components .password is None :
439
+ return None
440
+ return unquote (self .components .password )
421
441
422
442
@property
423
443
def hostname (self ) -> typing .Optional [str ]:
@@ -436,7 +456,7 @@ def database(self) -> str:
436
456
path = self .components .path
437
457
if path .startswith ("/" ):
438
458
path = path [1 :]
439
- return path
459
+ return unquote ( path )
440
460
441
461
@property
442
462
def options (self ) -> dict :
@@ -453,8 +473,8 @@ def replace(self, **kwargs: typing.Any) -> "DatabaseURL":
453
473
):
454
474
hostname = kwargs .pop ("hostname" , self .hostname )
455
475
port = kwargs .pop ("port" , self .port )
456
- username = kwargs .pop ("username" , self .username )
457
- password = kwargs .pop ("password" , self .password )
476
+ username = kwargs .pop ("username" , self .components . username )
477
+ password = kwargs .pop ("password" , self .components . password )
458
478
459
479
netloc = hostname
460
480
if port is not None :
0 commit comments