Skip to content

Commit f188058

Browse files
committed
第七章 FastAPI的数据库操作和多应用的目录结构设计
1 parent 2c466cf commit f188058

File tree

9 files changed

+225
-159
lines changed

9 files changed

+225
-159
lines changed

coronavirus.sqlite3

24 KB
Binary file not shown.

coronavirus/crud.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
#!/usr/bin/python3
2+
# -*- coding:utf-8 -*-
3+
# __author__ = '__Jack__'
4+
5+
from sqlalchemy.orm import Session
6+
7+
from coronavirus import models, schemas
8+
9+
10+
def get_city(db: Session, city_id: int):
11+
return db.query(models.City).filter(models.City.id == city_id).first()
12+
13+
14+
def get_city_by_name(db: Session, name: str):
15+
return db.query(models.City).filter(models.City.province == name).first()
16+
17+
18+
def get_cities(db: Session, skip: int = 0, limit: int = 10):
19+
return db.query(models.City).offset(skip).limit(limit).all()
20+
21+
22+
def create_city(db: Session, city: schemas.CreateCity):
23+
db_city = models.City(**city.dict())
24+
db.add(db_city)
25+
db.commit()
26+
db.refresh(db_city)
27+
return db_city
28+
29+
30+
def get_data(db: Session, skip: int = 0, limit: int = 10):
31+
return db.query(models.Data).offset(skip).limit(limit).all()
32+
33+
34+
def create_city_data(db: Session, data: schemas.CreateData, city_id: int):
35+
db_data = models.Data(**data.dict(), city_id=city_id)
36+
db.add(db_data)
37+
db.commit()
38+
db.refresh(db_data)
39+
return db_data

coronavirus/database.py

Lines changed: 5 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -2,74 +2,22 @@
22
# -*- coding:utf-8 -*-
33
# __author__ = '__Jack__'
44

5-
from sqlalchemy import create_engine, Column, String, Integer, BigInteger, Date, DateTime, ForeignKey, func
5+
from sqlalchemy import create_engine
66
from sqlalchemy.ext.declarative import declarative_base
7-
from sqlalchemy.orm import sessionmaker, relationship
7+
from sqlalchemy.orm import sessionmaker
88

9-
SQLALCHEMY_DATABASE_URL = 'sqlite:///./coronavirus.db'
9+
SQLALCHEMY_DATABASE_URL = 'sqlite:///./coronavirus.sqlite3'
1010
# SQLALCHEMY_DATABASE_URL = "postgresql://username:password@host:port/database_name" # MySQL或PostgreSQL的连接方法
1111

1212
engine = create_engine(
1313
# echo=True表示引擎将用repr()函数记录所有语句及其参数列表到日志
14-
# 由于SQLAlchemy是多线程,指定check_same_thread=False来让建立的对象任意线程都可使用
14+
# 由于SQLAlchemy是多线程,指定check_same_thread=False来让建立的对象任意线程都可使用。这个参数只在用SQLite数据库时设置
1515
SQLALCHEMY_DATABASE_URL, encoding='utf-8', echo=True, connect_args={'check_same_thread': False}
1616
)
1717

18-
# 在SQLAlchemy中,CRUD都是通过会话(session)进行的,所以我们必须要先创建会话
18+
# 在SQLAlchemy中,CRUD都是通过会话(session)进行的,所以我们必须要先创建会话,每一个SessionLocal实例就是一个数据库session
1919
# flush()是指发送数据库语句到数据库,但数据库不一定执行写入磁盘;commit()是指提交事务,将变更保存到数据库文件
2020
SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False, expire_on_commit=True)
2121

2222
# 创建基本映射类
2323
Base = declarative_base(bind=engine, name='Base')
24-
25-
26-
class City(Base):
27-
__tablename__ = 'city'
28-
29-
id = Column(Integer, primary_key=True, index=True, autoincrement=True)
30-
province = Column(String(100), unique=True, nullable=False, comment='省/直辖市')
31-
country = Column(String(100), nullable=False, comment='国家')
32-
country_code = Column(String(100), nullable=False, comment='国家代码')
33-
country_population = Column(BigInteger, nullable=False, comment='国家人口')
34-
data = relationship('Data', backref='city') # 'Data'是关联的类名;backref来指定反向访问的属性名称
35-
36-
created_at = Column(DateTime, server_default=func.now(), comment='创建时间')
37-
updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now(), comment='更新时间')
38-
39-
__mapper_args__ = {"order_by": country_code} # 默认是正序,倒序加上.desc()方法
40-
41-
def __repr__(self):
42-
return f'{self.country}_{self.province}'
43-
44-
45-
class Data(Base):
46-
__tablename__ = 'data'
47-
48-
id = Column(Integer, primary_key=True, index=True, autoincrement=True)
49-
city_id = Column(Integer, ForeignKey('city.id'), ondelete='CASCADE', comment='所属省/直辖市') # ForeignKey里的字符串格式不是类名.属性名,而是表名.字段名
50-
date = Column(Date, nullable=False, comment='数据日期')
51-
confirmed = Column(BigInteger, default=0, nullable=False, comment='确诊数量')
52-
deaths = Column(BigInteger, default=0, nullable=False, comment='死亡数量')
53-
recovered = Column(BigInteger, default=0, nullable=False, comment='痊愈数量')
54-
city = relationship('City', backref='data') # 'City'是关联的类名;backref来指定反向访问的属性名称
55-
56-
created_at = Column(DateTime, server_default=func.now(), comment='创建时间')
57-
updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now(), comment='更新时间')
58-
59-
__mapper_args__ = {"order_by": confirmed}
60-
61-
def __repr__(self):
62-
return f'{repr(self.date)}:确诊{self.confirmed}例'
63-
64-
65-
""" 附上三个SQLAlchemy教程
66-
67-
SQLAlchemy的基本操作大全
68-
http://www.taodudu.cc/news/show-175725.html
69-
70-
Python3+SQLAlchemy+Sqlite3实现ORM教程
71-
https://www.cnblogs.com/jiangxiaobo/p/12350561.html
72-
73-
SQLAlchemy基础知识 Autoflush和Autocommit
74-
https://zhuanlan.zhihu.com/p/48994990
75-
"""

coronavirus/main.py

Lines changed: 49 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -2,115 +2,63 @@
22
# -*- coding:utf-8 -*-
33
# __author__ = '__Jack__'
44

5-
from fastapi import Request, Depends, BackgroundTasks, APIRouter
6-
from pydantic import BaseModel
7-
from sqlalchemy.orm import Session
8-
9-
10-
# from coronavirus.database import engine, SessionLocal, Base
11-
5+
from typing import List
126

7+
from fastapi import APIRouter, Depends, HTTPException
138
from fastapi.staticfiles import StaticFiles
149
from fastapi.templating import Jinja2Templates
10+
from sqlalchemy.orm import Session
1511

12+
from coronavirus import crud, schemas
13+
from coronavirus.database import engine, Base, SessionLocal
1614

1715
application = APIRouter()
1816

19-
17+
# mount表示将某个目录下一个完全独立的应用挂载过来,这个不会在API交互文档中显示
2018
application.mount('/static', StaticFiles(directory='./coronavirus/static'), name='static')
2119
templates = Jinja2Templates(directory='./coronavirus/templates')
2220

21+
Base.metadata.create_all(bind=engine)
22+
23+
24+
def get_db():
25+
db = SessionLocal()
26+
try:
27+
yield db
28+
finally:
29+
db.close()
30+
31+
32+
@application.post("/create_city", response_model=schemas.ReadCity)
33+
def create_city(city: schemas.CreateCity, db: Session = Depends(get_db)):
34+
db_city = crud.get_city_by_name(db, name=city.province)
35+
if db_city:
36+
raise HTTPException(status_code=400, detail="City already registered")
37+
return crud.create_city(db=db, city=city)
38+
39+
40+
@application.get("/get_city/{city}", response_model=schemas.ReadCity)
41+
def get_city(city: str, db: Session = Depends(get_db)):
42+
db_city = crud.get_city_by_name(db, name=city)
43+
if db_city is None:
44+
raise HTTPException(status_code=404, detail="City not found")
45+
return db_city
46+
47+
48+
@application.get("/get_cities", response_model=List[schemas.ReadCity])
49+
def get_cities(skip: int = 0, limit: int = 100, db: Session = Depends(get_db)):
50+
cities = crud.get_cities(db, skip=skip, limit=limit)
51+
return cities
52+
53+
54+
@application.post("/create_data", response_model=schemas.ReadData)
55+
def create_data_for_city(city: str, data: schemas.CreateData, db: Session = Depends(get_db)):
56+
db_city = crud.get_city_by_name(db, name=city)
57+
data = crud.create_city_data(db=db, data=data, city_id=db_city.id)
58+
return data
59+
2360

24-
# Base.metadata.create_all(bind=engine)
25-
26-
27-
@application.get("/")
28-
async def index(request: Request):
29-
return templates.TemplateResponse("index.html", {"request": request, "title": "This is title"}) # {"request": request}是必须的
30-
31-
32-
@application.get("/coronavirus")
33-
async def coronavirus():
34-
"""This is a simple tutorial"""
35-
return {"message": {"This is another route"}}
36-
37-
38-
# class StockRequest(BaseModel):
39-
# symbol: str
40-
#
41-
#
42-
# def get_db():
43-
# try:
44-
# db = SessionLocal()
45-
# yield db
46-
# finally:
47-
# db.close()
48-
#
49-
#
50-
# @application.get("/")
51-
# def home(request: Request, forward_pe=None, dividend_yield=None, ma50=None, ma200=None, db: Session = Depends(get_db)):
52-
# """
53-
# displays the stock screener dashboard / homepage
54-
# :return:
55-
# """
56-
# stocks = db.query(Stock)
57-
#
58-
# if forward_pe:
59-
# stocks = stocks.filter(Stock.forward_pe < forward_pe)
60-
#
61-
# if dividend_yield:
62-
# stocks = stocks.filter(Stock.dividend_yield > dividend_yield)
63-
#
64-
# if ma50:
65-
# stocks = stocks.filter(Stock.price > Stock.ma50)
66-
#
67-
# if ma200:
68-
# stocks = stocks.filter(Stock.price > Stock.ma200)
69-
#
70-
# return templates.TemplateResponse("home.html", {
71-
# "request": request,
72-
# "stocks": stocks
73-
# })
74-
#
75-
#
76-
# def fetch_stock_data(id_: int):
77-
# """
78-
# fetch data from yahoo finance
79-
# :param id_:
80-
# :return:
81-
# """
82-
# db = SessionLocal()
83-
# stock = db.query(Stock).filter(Stock.id == id_).first()
84-
#
85-
# yahoo_data = yf.Ticker(stock.symbol)
86-
#
87-
# stock.ma50 = yahoo_data.info["fiftyDayAverage"]
88-
# stock.ma200 = yahoo_data.info["twoHundredDayAverage"]
89-
# stock.price = yahoo_data.info["previousClose"]
90-
# stock.forward_pe = yahoo_data.info["forwardPE"]
91-
# stock.forward_eps = yahoo_data.info["forwardEps"]
92-
# if yahoo_data.info["dividendYield"]:
93-
# stock.dividend_yield = yahoo_data.info["dividendYield"] * 100
94-
#
95-
# db.add(stock)
96-
# db.commit()
97-
#
98-
#
99-
# @application.post("/stock")
100-
# async def create_stock(stock_request: StockRequest, background_tasks: BackgroundTasks, db: Session = Depends(get_db)):
101-
# """
102-
# created a stock and stores it in the database
103-
# :return:
104-
# """
105-
# stock = Stock()
106-
# stock.symbol = stock_request.symbol
107-
#
108-
# db.add(stock)
109-
# db.commit()
110-
#
111-
# background_tasks.add_task(fetch_stock_data, stock.id)
112-
#
113-
# return {
114-
# "code": "success",
115-
# "message": "stock created"
116-
# }
61+
@application.get("/get_data", response_model=List[schemas.ReadData])
62+
def get_data(skip: int = 0, limit: int = 100, db: Session = Depends(get_db)):
63+
data = crud.get_data(db, skip=skip, limit=limit)
64+
return data

coronavirus/models.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
#!/usr/bin/python3
2+
# -*- coding:utf-8 -*-
3+
# __author__ = '__Jack__'
4+
5+
from sqlalchemy import Column, String, Integer, BigInteger, Date, DateTime, ForeignKey, func
6+
from sqlalchemy.orm import relationship
7+
8+
from .database import Base
9+
10+
11+
class City(Base):
12+
__tablename__ = 'city' # 数据表的表名
13+
14+
id = Column(Integer, primary_key=True, index=True, autoincrement=True)
15+
province = Column(String(100), unique=True, nullable=False, comment='省/直辖市')
16+
country = Column(String(100), nullable=False, comment='国家')
17+
country_code = Column(String(100), nullable=False, comment='国家代码')
18+
country_population = Column(BigInteger, nullable=False, comment='国家人口')
19+
data = relationship('Data', back_populates='city') # 'Data'是关联的类名;back_populates来指定反向访问的属性名称
20+
21+
created_at = Column(DateTime, server_default=func.now(), comment='创建时间')
22+
updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now(), comment='更新时间')
23+
24+
__mapper_args__ = {"order_by": country_code} # 默认是正序,倒序加上.desc()方法
25+
26+
def __repr__(self):
27+
return f'{self.country}_{self.province}'
28+
29+
30+
class Data(Base):
31+
__tablename__ = 'data'
32+
33+
id = Column(Integer, primary_key=True, index=True, autoincrement=True)
34+
city_id = Column(Integer, ForeignKey('city.id'), comment='所属省/直辖市') # ForeignKey里的字符串格式不是类名.属性名,而是表名.字段名
35+
date = Column(Date, nullable=False, comment='数据日期')
36+
confirmed = Column(BigInteger, default=0, nullable=False, comment='确诊数量')
37+
deaths = Column(BigInteger, default=0, nullable=False, comment='死亡数量')
38+
recovered = Column(BigInteger, default=0, nullable=False, comment='痊愈数量')
39+
city = relationship('City', back_populates='data') # 'City'是关联的类名;back_populates来指定反向访问的属性名称
40+
41+
created_at = Column(DateTime, server_default=func.now(), comment='创建时间')
42+
updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now(), comment='更新时间')
43+
44+
__mapper_args__ = {"order_by": confirmed}
45+
46+
def __repr__(self):
47+
return f'{repr(self.date)}:确诊{self.confirmed}例'
48+
49+
50+
""" 附上三个SQLAlchemy教程
51+
52+
SQLAlchemy的基本操作大全
53+
http://www.taodudu.cc/news/show-175725.html
54+
55+
Python3+SQLAlchemy+Sqlite3实现ORM教程
56+
https://www.cnblogs.com/jiangxiaobo/p/12350561.html
57+
58+
SQLAlchemy基础知识 Autoflush和Autocommit
59+
https://zhuanlan.zhihu.com/p/48994990
60+
"""

coronavirus/schemas.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
#!/usr/bin/python3
2+
# -*- coding:utf-8 -*-
3+
# __author__ = '__Jack__'
4+
5+
from datetime import date as date_
6+
from datetime import datetime
7+
from typing import List
8+
9+
from pydantic import BaseModel
10+
11+
12+
class CreateData(BaseModel):
13+
date: date_
14+
confirmed: int = 0
15+
deaths: int = 0
16+
recovered: int = 0
17+
18+
19+
class CreateCity(BaseModel):
20+
province: str
21+
country: str
22+
country_code: str
23+
country_population: int
24+
data: List[CreateData] = []
25+
26+
27+
class ReadData(CreateData):
28+
id: int
29+
city_id: int
30+
updated_at: datetime
31+
created_at: datetime
32+
33+
class Config:
34+
orm_mode = True
35+
36+
37+
class ReadCity(CreateCity):
38+
id: int
39+
updated_at: datetime
40+
created_at: datetime
41+
42+
class Config:
43+
orm_mode = True

0 commit comments

Comments
 (0)