Skip to content

Commit bb66356

Browse files
authored
ENH: Add Selectable (#305)
* ENH: Add Selectable closes #304 * TST: Use mocking to get around sqlalchemy typing
1 parent ebcb807 commit bb66356

File tree

2 files changed

+42
-10
lines changed

2 files changed

+42
-10
lines changed

pandas-stubs/io/sql.pyi

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,14 @@ from typing import (
55
Generator,
66
Iterable,
77
Literal,
8+
Union,
89
overload,
910
)
1011

1112
from pandas.core.base import PandasObject
1213
from pandas.core.frame import DataFrame
1314
import sqlalchemy.engine
15+
import sqlalchemy.sql.expression
1416

1517
from pandas._typing import (
1618
DtypeArg,
@@ -20,10 +22,16 @@ from pandas._typing import (
2022
# TODO: Remove after switch to 1.5.x, moved to pandas.errors
2123
class DatabaseError(IOError): ...
2224

25+
_SQLConnection = Union[
26+
str,
27+
sqlalchemy.engine.Connectable,
28+
sqlite3.Connection,
29+
]
30+
2331
@overload
2432
def read_sql_table(
2533
table_name: str,
26-
con: str | sqlalchemy.engine.Connectable | sqlite3.Connection,
34+
con: _SQLConnection,
2735
schema: str | None = ...,
2836
index_col: str | list[str] | None = ...,
2937
coerce_float: bool = ...,
@@ -35,7 +43,7 @@ def read_sql_table(
3543
@overload
3644
def read_sql_table(
3745
table_name: str,
38-
con: str | sqlalchemy.engine.Connectable | sqlite3.Connection,
46+
con: _SQLConnection,
3947
schema: str | None = ...,
4048
index_col: str | list[str] | None = ...,
4149
coerce_float: bool = ...,
@@ -45,8 +53,8 @@ def read_sql_table(
4553
) -> DataFrame: ...
4654
@overload
4755
def read_sql_query(
48-
sql: str,
49-
con: str | sqlalchemy.engine.Connectable | sqlite3.Connection,
56+
sql: str | sqlalchemy.sql.expression.Selectable,
57+
con: _SQLConnection,
5058
index_col: str | list[str] | None = ...,
5159
coerce_float: bool = ...,
5260
params: list[str] | tuple[str, ...] | dict[str, str] | None = ...,
@@ -57,8 +65,8 @@ def read_sql_query(
5765
) -> Generator[DataFrame, None, None]: ...
5866
@overload
5967
def read_sql_query(
60-
sql: str,
61-
con: str | sqlalchemy.engine.Connectable | sqlite3.Connection,
68+
sql: str | sqlalchemy.sql.expression.Selectable,
69+
con: _SQLConnection,
6270
index_col: str | list[str] | None = ...,
6371
coerce_float: bool = ...,
6472
params: list[str] | tuple[str, ...] | dict[str, str] | None = ...,
@@ -68,8 +76,8 @@ def read_sql_query(
6876
) -> DataFrame: ...
6977
@overload
7078
def read_sql(
71-
sql: str,
72-
con: str | sqlalchemy.engine.Connectable | sqlite3.Connection,
79+
sql: str | sqlalchemy.sql.expression.Selectable,
80+
con: _SQLConnection,
7381
index_col: str | list[str] | None = ...,
7482
coerce_float: bool = ...,
7583
params: list[str] | tuple[str, ...] | dict[str, str] | None = ...,
@@ -80,8 +88,8 @@ def read_sql(
8088
) -> Generator[DataFrame, None, None]: ...
8189
@overload
8290
def read_sql(
83-
sql: str,
84-
con: str | sqlalchemy.engine.Connectable | sqlite3.Connection,
91+
sql: str | sqlalchemy.sql.expression.Selectable,
92+
con: _SQLConnection,
8593
index_col: str | list[str] | None = ...,
8694
coerce_float: bool = ...,
8795
params: list[str] | tuple[str, ...] | dict[str, str] | None = ...,

tests/test_io.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,9 @@
4242
from pandas._testing import ensure_clean
4343
import pytest
4444
import sqlalchemy
45+
import sqlalchemy.ext.declarative
46+
import sqlalchemy.orm
47+
import sqlalchemy.orm.decl_api
4548
from typing_extensions import assert_type
4649

4750
from tests import (
@@ -795,3 +798,24 @@ def test_csv_quoting():
795798
assert_type(DF.to_csv(path, quoting=csv.QUOTE_NONNUMERIC), None), type(None)
796799
)
797800
check(assert_type(DF.to_csv(path, quoting=csv.QUOTE_MINIMAL), None), type(None))
801+
802+
803+
def test_sqlalchemy_selectable() -> None:
804+
with ensure_clean() as path:
805+
db_uri = "sqlite:///" + path
806+
engine = sqlalchemy.create_engine(db_uri)
807+
808+
if TYPE_CHECKING:
809+
# Just type checking since underlying dB does not exist
810+
class Base(metaclass=sqlalchemy.orm.decl_api.DeclarativeMeta):
811+
__abstract__ = True
812+
813+
class Temp(Base):
814+
__tablename__ = "part"
815+
quantity = sqlalchemy.Column(sqlalchemy.Integer)
816+
817+
Session = sqlalchemy.orm.sessionmaker(engine)
818+
with Session() as session:
819+
pd.read_sql(
820+
session.query(Temp.quantity).statement, session.connection()
821+
)

0 commit comments

Comments
 (0)