fix: refactor errors

fastapi-admin
Pierre-Louis Guhur 1 year ago
parent caacaad028
commit f7fc535a89

@ -1,3 +1,5 @@
SECRET=
DB_NAME=mj
DB_HOST=mj_db
DB_PORT=5433

@ -26,10 +26,6 @@ You certainly want to apply databases migrations with:
`docker/test.sh`
## Create databases migrations
`sudo docker/makemigrations.sh`
## Local development
@ -68,6 +64,7 @@ uvicorn app.main:app --reload --env-file .env.local
http://127.0.0.1:8000/redoc
```
If you need to alter the database, you can create new migrations using [alembic](https://alembic.sqlalchemy.org/en/latest/index.html).
## TODO

@ -0,0 +1,47 @@
import json
from collections.abc import Mapping
import typing as t
from jose import jws, JWSError
from . import errors
from .settings import settings
def jws_verify(token: str) -> Mapping[str, t.Any]:
"""
Verify the content of a JWS token
"""
try:
data = jws.verify(token, settings.secret, algorithms=["HS256"])
except JWSError:
raise errors.UnauthorizedError("Can not decode token")
if not isinstance(data, bytes):
raise errors.BadRequestError("Ununderstandable token")
try:
return json.loads(data)
except json.decoder.JSONDecodeError:
raise errors.BadRequestError("Ununderstandable token")
def create_ballot_token(
vote_ids: int | list[int],
election_id: int,
) -> str:
if isinstance(vote_ids, int):
vote_ids = [vote_ids]
return jws.sign(
{"votes": vote_ids, "election": election_id},
settings.secret,
algorithm="HS256",
)
def create_admin_token(
election_id: int,
) -> str:
return jws.sign(
{"admin": True, "election": election_id},
settings.secret,
algorithm="HS256",
)

@ -1,9 +1,12 @@
from datetime import datetime
from collections import defaultdict
from typing import DefaultDict
from sqlalchemy.orm import Session
from sqlalchemy import func
from sqlalchemy import func, insert
from majority_judgment import majority_judgment
from . import models, schemas, errors
from .settings import settings
from .auth import create_ballot_token, create_admin_token, jws_verify
def get_election(db: Session, election_id: int):
@ -82,6 +85,24 @@ def _create_election_without_candidates_or_grade(
return db_election
def create_invite_tokens(
db: Session,
election_id: int,
num_candidates: int,
num_voters: int,
) -> list[str]:
now = datetime.now()
params = {"date_created": now, "date_modified": now, "election_id": election_id}
db_votes = [models.Vote(**params) for _ in range(num_voters * num_candidates)]
db.bulk_save_objects(db_votes, return_defaults=True)
vote_ids = [int(str(v.id)) for v in db_votes]
tokens = [
create_ballot_token(vote_ids[i::num_candidates], election_id)
for i in range(num_voters)
]
return tokens
def create_election(
db: Session, election: schemas.ElectionCreate
) -> schemas.ElectionGet:
@ -91,25 +112,26 @@ def create_election(
# Then, we add separatly candidates and grades
for candidate in election.candidates:
candidate_rel = schemas.CandidateRelational(
**{**candidate.dict(), "election_id": db_election.id}
)
params = candidate.dict()
params["election_id"] = db_election.id
candidate_rel = schemas.CandidateRelational(**params)
create_candidate(db, candidate_rel, False)
for grade in election.grades:
grade_rel = schemas.GradeRelational(
**{**grade.dict(), "election_id": db_election.id}
)
params = grade.dict()
params["election_id"] = db_election.id
grade_rel = schemas.GradeRelational(**params)
create_grade(db, grade_rel, False)
db.commit()
db.refresh(db_election)
# create_
# TODO JWT token for invites
invites: list[str] = []
invites: list[str] = create_invite_tokens(
db, int(str(db_election.id)), len(election.candidates), election.num_voters
)
# TODO JWT token for admin panel
admin = ""
admin = create_admin_token(int(str(db_election.id)))
election_and_invites = schemas.ElectionAndInvitesGet.from_orm(db_election)
election_and_invites.invites = invites
@ -118,41 +140,67 @@ def create_election(
return election_and_invites
def create_vote(db: Session, vote: schemas.VoteCreate) -> schemas.VoteGet:
db_vote = models.Vote(**vote.dict())
db.add(db_vote)
def create_vote(db: Session, vote: schemas.BallotCreate) -> schemas.BallotGet:
votes = vote.votes
db.commit()
db.refresh(db_vote)
if votes == []:
raise errors.BadRequestError("The ballot contains no vote")
return db_vote
election_id = votes[0].election_id
# Check if the election is open
db_election = get_election(db, election_id)
if db_election is None:
raise errors.NotFoundError("Unknown election.")
if db_election.private:
raise errors.BadRequestError(
"The election is private. You can not create new votes"
)
def get_vote(db: Session, vote_id: int) -> models.Vote:
# TODO check with JWT tokens the authorization
votes_by_id = db.query(models.Vote).filter(models.Vote.id == vote_id)
db_votes = [models.Vote(**v.dict()) for v in vote.votes]
db.bulk_save_objects(db_votes, return_defaults=True)
if votes_by_id.count() > 1:
raise errors.InconsistentDatabaseError(
"votes", f"Several votes have the same primary keys {vote_id}"
)
votes_get = [schemas.VoteGet.from_orm(v) for v in db_votes]
vote_ids = [v.id for v in votes_get]
token = create_ballot_token(vote_ids, election_id)
return schemas.BallotGet(votes=votes_get, token=token)
vote = votes_by_id.first()
if vote is not None:
return vote
votes_by_ref = db.query(models.Vote).filter(models.Vote.ref == vote_id)
def update_vote(db: Session, vote: schemas.BallotUpdate) -> schemas.BallotGet:
votes = vote.votes
token = vote.token
if votes_by_ref.count() > 1:
raise errors.InconsistentDatabaseError(
"votes", f"Several votes have the same reference {vote_id}"
)
if votes == []:
raise errors.BadRequestError("The ballot contains no vote")
election_id = votes[0].election_id
# Check if the election exists
db_election = get_election(db, election_id)
if db_election is None:
raise errors.NotFoundError("Unknown election.")
db_votes = [models.Vote(**v.dict()) for v in vote.votes]
db.bulk_save_objects(db_votes, return_defaults=True)
votes_get = [schemas.VoteGet.from_orm(v) for v in db_votes]
vote_ids = [v.id for v in votes_get]
token = create_ballot_token(vote_ids, election_id)
return schemas.BallotGet(votes=votes_get, token=token)
vote = votes_by_ref.first()
if vote is not None:
return vote
raise errors.NotFoundError("votes")
def get_votes(db: Session, token: str) -> schemas.BallotGet:
data = jws_verify(token)
vote_ids = data["votes"]
election_id = data["election"]
votes = db.query(models.Vote).filter(
models.Vote.id.in_((vote_ids))
& (models.Vote.candidate_id.is_not(None))
& (models.Vote.election_id == election_id)
)
votes_get = [schemas.VoteGet.from_orm(v) for v in votes.all()]
return schemas.BallotGet(token=token, votes=votes_get)
def get_results(db: Session, election_id: int) -> schemas.ResultsGet:
@ -175,7 +223,7 @@ def get_results(db: Session, election_id: int) -> schemas.ResultsGet:
for c, votes in ballots.items()
}
ranking = majority_judgment(merit_profile)
ranking = majority_judgment(merit_profile) # pyright: ignore
results = schemas.ResultsGet.from_orm(db_election)

@ -2,10 +2,12 @@
Utility to handle exceptions
"""
class NotFoundError(Exception):
"""
An item can not be found
"""
def __init__(self, name: str):
self.name = name
@ -14,6 +16,25 @@ class InconsistentDatabaseError(Exception):
"""
An inconsistent value was detected on the database
"""
def __init__(self, name: str, details: str | None = None):
self.name = name
self.details = details
class BadRequestError(Exception):
"""
The request is made inconsistent
"""
def __init__(self, name: str):
self.name = name
class UnauthorizedError(Exception):
"""
The verification could not be verified
"""
def __init__(self, name: str):
self.name = name

@ -1,8 +1,11 @@
import typing as t
import json
from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse
from fastapi.responses import JSONResponse, PlainTextResponse
from fastapi.middleware.cors import CORSMiddleware
from sqlalchemy.orm import Session
from jose import jwe, jws
from jose.exceptions import JWEError, JWSError
from . import crud, models, schemas, errors
from .database import get_db, engine
@ -34,6 +37,21 @@ async def not_found_exception_handler(request: Request, exc: errors.NotFoundErro
)
@app.exception_handler(errors.UnauthorizedError)
async def unauthorized_exception_handler(request: Request, exc: errors.NotFoundError):
return JSONResponse(
status_code=401, content={"message": "Unautorized", "details": exc.name}
)
@app.exception_handler(errors.BadRequestError)
async def bad_request_exception_handler(request: Request, exc: errors.NotFoundError):
return JSONResponse(
status_code=400,
content={"message": f"Bad Request", "details": exc.name},
)
@app.exception_handler(errors.InconsistentDatabaseError)
async def inconsistent_database_exception_handler(
request: Request, exc: errors.InconsistentDatabaseError
@ -62,15 +80,25 @@ def create_election(election: schemas.ElectionCreate, db: Session = Depends(get_
return crud.create_election(db=db, election=election)
@app.post("/votes", response_model=schemas.VoteGet)
def create_vote(vote: schemas.VoteCreate, db: Session = Depends(get_db)):
return crud.create_vote(db=db, vote=vote)
@app.post("/votes", response_model=schemas.BallotGet)
def create_vote(vote: schemas.BallotCreate, db: Session = Depends(get_db)):
try:
return crud.create_vote(db=db, vote=vote)
except JWSError:
raise errors.UnauthorizedError("Unverified token")
@app.put("/votes", response_model=schemas.BallotGet)
def update_vote(vote: schemas.BallotUpdate, db: Session = Depends(get_db)):
try:
return crud.update_vote(db=db, vote=vote)
except JWSError:
raise errors.UnauthorizedError("Unverified token")
@app.get("/votes/{vote_id}", response_model=schemas.VoteGet)
def get_vote(vote_id: int, db: Session = Depends(get_db)):
# TODO assert with a JWT token that we are allowed to read the vote
return crud.get_vote(db=db, vote_id=vote_id)
@app.get("/votes/{token}", response_model=schemas.BallotGet)
def get_vote(token: str, db: Session = Depends(get_db)):
return crud.get_votes(db=db, token=token)
@app.get("/results/{election_id}", response_model=schemas.ResultsGet)

@ -63,10 +63,10 @@ class Vote(Base):
date_created = Column(DateTime)
date_modified = Column(DateTime)
candidate_id = Column(Integer, ForeignKey("candidates.id"))
candidate_id = Column(Integer, ForeignKey("candidates.id"), nullable=True)
candidate = relationship("Candidate", back_populates="votes")
grade_id = Column(Integer, ForeignKey("grades.id"))
grade_id = Column(Integer, ForeignKey("grades.id"), nullable=True)
grade = relationship("Grade", back_populates="votes")
election_id = Column(Integer, ForeignKey("elections.id"))

@ -92,8 +92,8 @@ class GradeRelational(GradeBase):
class VoteBase(BaseModel):
candidate: CandidateGet
grade: GradeGet
candidate: CandidateGet | None = Field(default=None)
grade: GradeGet | None = Field(default=None)
date_created: datetime = Field(default_factory=datetime.now)
date_modified: datetime = Field(default_factory=datetime.now)
@ -106,6 +106,7 @@ class VoteBase(BaseModel):
class VoteGet(VoteBase):
id: int
election_id: int
token: str = ""
class VoteCreate(BaseModel):
@ -121,6 +122,25 @@ class VoteCreate(BaseModel):
orm_mode = True
class VoteUpdate(VoteCreate):
id: int | None = None
token: str = ""
class BallotGet(BaseModel):
votes: list[VoteGet]
token: str
class BallotCreate(BaseModel):
votes: list[VoteCreate]
class BallotUpdate(BaseModel):
votes: list[VoteUpdate]
token: str
def _in_a_long_time() -> datetime:
"""
Provides the date in the future

@ -1,9 +1,14 @@
import random
import sys
from pydantic import BaseSettings
class Settings(BaseSettings):
sqlite: bool = False
secret: str = ""
aes_key: bytes = b""
postgres_password: str = ""
postgres_user: str = "mj"
postgres_name: str = "mj"
@ -20,4 +25,21 @@ class Settings(BaseSettings):
env_file = ".env"
def get_random_key(length: int, rng: random.Random) -> bytes:
"""
Inspired from https://stackoverflow.com/a/37357035/4986615
"""
if length == 0:
return b""
integer = rng.getrandbits(length * 8)
result = integer.to_bytes(length, sys.byteorder)
return result
settings = Settings()
if settings.secret == "":
raise RuntimeError("Please generate a secret key")
rng = random.Random(settings.secret)
settings.aes_key = get_random_key(16, rng)

@ -143,24 +143,31 @@ def test_create_vote():
# We create votes using the ID
votes = _generate_votes_from_response("id", data)
vote_ids = []
for vote in votes:
response = client.post(f"/votes", json=vote)
assert response.status_code == 200, response.text
data = response.json()
assert data["grade"]["id"] == vote["grade_id"]
assert data["candidate"]["id"] == vote["candidate_id"]
assert data["election_id"] == election_id
vote_ids.append(data["id"])
# Now, we check that we can correctly read them
for vote_id, vote in zip(vote_ids, votes):
response = client.get(f"/votes/{vote_id}")
assert response.status_code == 200, response.text
data = response.json()
assert data["grade"]["id"] == vote["grade_id"]
assert data["candidate"]["id"] == vote["candidate_id"]
assert data["election_id"] == election_id
response = client.post(f"/votes", json={"votes": votes})
assert response.status_code == 200, response.text
data = response.json()
for v1, v2 in zip(votes, data["votes"]):
assert v2["grade"]["id"] == v1["grade_id"]
assert v2["candidate"]["id"] == v1["candidate_id"]
assert v2["election_id"] == election_id
token = data["token"]
# Now, we check that we need the righ token to read the votes
response = client.get(f"/votes/{token}WRONG")
assert response.status_code == 401, response.text
response = client.get(f"/votes/{token}")
assert response.status_code == 200, response.text
data = response.json()
for v1, v2 in zip(votes, data["votes"]):
assert v2["grade"]["id"] == v1["grade_id"]
assert v2["candidate"]["id"] == v1["candidate_id"]
assert v2["election_id"] == election_id
def test_cannot_create_vote_on_private_election():
assert False
def test_get_results():

@ -0,0 +1,56 @@
import pytest
from jose import jws
from ..auth import create_ballot_token, jws_verify, create_admin_token
from ..settings import settings
from .. import errors
def test_jws_verify_dict():
"""
Can verify a JWS token given as a dict
"""
payload = {"bar": "foo"}
token = jws.sign(payload, settings.secret, algorithm="HS256")
data = jws_verify(token)
assert payload == data
def test_jws_verify_secret():
"""
It must fail with a wrong key
"""
payload = {"bar": "foo"}
token = jws.sign(payload, settings.secret + "WRONG", algorithm="HS256")
with pytest.raises(errors.UnauthorizedError):
jws_verify(token)
def test_jws_verify_bytes():
"""
It must fail with bytes content
"""
payload = b"foo"
token = jws.sign(payload, settings.secret, algorithm="HS256")
with pytest.raises(errors.BadRequestError):
jws_verify(token)
def test_ballot_token():
"""
Can verify ballot tokens with MANY different tokens
"""
vote_ids = list(range(1000))
election_id = 0
token = create_ballot_token(vote_ids, election_id)
data = jws_verify(token)
assert data == {"votes": vote_ids, "election": election_id}
def test_admin_token():
"""
Can verify ballot tokens with MANY different tokens
"""
election_id = 0
token = create_admin_token(election_id)
data = jws_verify(token)
assert data == {"admin": True, "election": election_id}

@ -1 +1,7 @@
docker exec -it majority-judgment-api-python-mj_api-1 pytest
tmpfile=$(mktemp /tmp/mj-api.XXXX)
echo "SECRET=mysecrettoken" >> $tmpfile
echo "SQLITE=True" >> $tmpfile
docker run --env-file $tmpfile \
majority-judgment/api-python:latest \
pytest

@ -4,3 +4,4 @@ sqlalchemy==2.0.0b3
pydantic==1.10.2
psycopg2==2.9.5
git+https://github.com/MieuxVoter/majority-judgment-library-python
python-jose==3.3.0

Loading…
Cancel
Save