Skip to content
This repository has been archived by the owner on Jan 20, 2024. It is now read-only.

Commit

Permalink
Browse files Browse the repository at this point in the history
Do a major refactoring of DB code (see #7)
  • Loading branch information
tdemin committed Jan 14, 2020
1 parent 1ed87ad commit f1abdff
Show file tree
Hide file tree
Showing 11 changed files with 309 additions and 305 deletions.
148 changes: 148 additions & 0 deletions project_amber/controllers/auth.py
@@ -0,0 +1,148 @@
from hashlib import sha256
from base64 import b64encode
from typing import List

from bcrypt import hashpw, gensalt, checkpw

from project_amber.const import MSG_USER_EXISTS
from project_amber.db import db
from project_amber.helpers import time
from project_amber.handlers import LoginUser
from project_amber.handlers.const import API_PASSWORD
from project_amber.errors import Unauthorized, NotFound, Conflict
from project_amber.logging import error
from project_amber.models.auth import User, Session


def prehash(password: str) -> bytes:
"""
Returns a "normalized" representation of the password that works
with bcrypt even when the password is longer than 72 chars.
"""
return b64encode(sha256(password.encode()).digest())


def gen_hashed_pw(password: str) -> bytes:
"""
Returns a bcrypt password hash with random salt.
"""
return hashpw(prehash(password), gensalt()).decode()


def gen_token() -> str:
"""
Returns a new freshly generated auth token.
"""
return sha256(gensalt() + bytes(str(time()).encode())).hexdigest()


class UserController:
user: LoginUser = None

def __init__(self, user: LoginUser):
self.user = user

def add_user(self, name: str, password: str) -> int:
"""
Creates a new user. Returns their ID on success.
"""
# does a user with this name already exist?
if not db.session.query(User).filter_by(name=name).one_or_none() is None:
raise Conflict(MSG_USER_EXISTS)
hashed_pw = gen_hashed_pw(password)
user = User(name=name, password=hashed_pw)
db.session.add(user)
db.session.commit()
return user.id

def update_user(self, **kwargs) -> int:
"""
Updates user data in the database. Returns their ID on success.
"""
user_record = db.session.query(User).filter_by(id=self.user.id).one()
for attribute in kwargs:
if attribute == API_PASSWORD:
user_record.password = gen_hashed_pw(kwargs[API_PASSWORD])
db.session.commit()
return self.user.id

def remove_user(self) -> int:
"""
Removes a user from the database. Returns their ID.
"""
user = db.session.query(User).filter_by(id=self.user.id).one_or_none()
try:
db.session.delete(user)
db.session.commit()
# pylint: disable=bare-except
except:
error("Failed to remove user %s!" % user.name)
return self.user.id

def verify_pw(self, uid: int, password: str) -> bool:
"""
Verifies user's password with bcrypt's checkpw(). Returns `True`, if
the passwords match, and False otherwise.
"""
user = db.session.query(User).filter_by(id=uid).one()
user_pass = user.password
if isinstance(user_pass, str):
user_pass = user_pass.encode()
return checkpw(prehash(password), user_pass)

def create_session(self, name: str, password: str, ip_addr: str) -> str:
"""
Creates a new user session. Returns an auth token.
"""
user = db.session.query(User).filter_by(name=name).one_or_none()
token: str
if user is None:
raise Unauthorized
if self.verify_pw(user.id, password):
token = gen_token()
session = Session(token=token, user=user.id, login_time=time(), address=ip_addr)
db.session.add(session)
db.session.commit()
else:
raise Unauthorized
return token

def remove_session(self) -> str:
"""
Logs the user out by removing their token from the database. Returns
the token on success.
"""
session = db.session.query(Session).filter_by(token=self.user.token,
user=self.user.id).one_or_none()
if session is None:
raise NotFound
db.session.delete(session)
db.session.commit()
return self.user.token

def remove_session_by_id(self, sid: int) -> int:
"""
Removes a user session by session ID. Returns the session ID on success.
"""
session = db.session.query(Session).filter_by(id=sid, user=self.user.id).one_or_none()
if session is None:
raise NotFound
db.session.delete(session)
db.session.commit()
return sid

def get_sessions(self) -> List[Session]:
"""
Returns a list of sessions of a user (class `Session`).
"""
sessions = db.session.query(Session).filter_by(user=self.user.id).all()
return sessions

def get_session(self, sid: int) -> Session:
"""
Returns a single `Session` by its ID.
"""
session = db.session.query(Session).filter_by(id=sid, user=self.user.id).one_or_none()
if session is None:
raise NotFound
return session
110 changes: 110 additions & 0 deletions project_amber/controllers/task.py
@@ -0,0 +1,110 @@
from typing import List

from project_amber.const import MSG_TASK_NOT_FOUND, MSG_TASK_DANGEROUS, \
MSG_TEXT_NOT_SPECIFIED
from project_amber.db import db
from project_amber.errors import NotFound, BadRequest
from project_amber.handlers import LoginUser
from project_amber.helpers import time
from project_amber.models.task import Task


class TaskController:
user: LoginUser = None

def __init__(self, user: LoginUser):
self.user = user

def add_task(self, data: dict) -> int:
"""
Creates a new task. Returns its ID.
"""
task = Task(self.user.id, data)
if task.text is None: raise BadRequest(MSG_TEXT_NOT_SPECIFIED)
if task.status is None: task.status = 0
parent_id = task.parent_id
if parent_id:
parent = db.session.query(Task).filter_by(id=parent_id,
owner=self.user.id).one_or_none()
if parent is None:
raise NotFound(MSG_TASK_NOT_FOUND)
task.add()
db.session.commit()
self.update_children(task.id)
# TODO: can we remove the second commit here?
db.session.commit()
return task.id

def get_task(self, task_id: int) -> Task:
"""
Returns an instance of `Task`, given the ID.
"""
task = db.session.query(Task).filter_by(id=task_id, owner=self.user.id).one_or_none()
if task is None:
raise NotFound(MSG_TASK_NOT_FOUND)
return task

def get_tasks(self, text: str = None) -> List[Task]:
"""
Returns a list containing tasks from a certain user. If the second
parameter is specified, this will return the tasks that have this text
in their description (`text in Task.text`).
"""
req = db.session.query(Task).filter_by(owner=self.user.id)
if text is None:
return req.all()
return req.filter(Task.text.ilike("%{0}%".format(text))).all()

def update_children(self, task_id: int):
"""
Recursively updates children lists for the children nodes of
a task subtree.
"""
task = self.get_task(task_id)
if task.parent_id:
parent = self.get_task(task.parent_id)
parent_list = parent.getParents()
parent_list.append(parent.id)
task.setParents(parent_list)
else:
task.setParents(list())
children = db.session.query(Task).filter_by(parent_id=task_id).all()
for child in children:
self.update_children(child.id)

def update_task(self, task_id: int, data: dict) -> int:
"""
Updates the task details. Returns its ID.
"""
task = self.get_task(task_id)
new_details = Task(self.user.id, data)
task.merge(new_details)
if not new_details.parent_id is None:
if new_details.parent_id == 0:
# promote task to the top level
task.parent_id = None
self.update_children(task.id)
else:
new_parent = self.get_task(new_details.parent_id)
if task.id in new_parent.getParents() or task.id == new_parent.id:
raise BadRequest(MSG_TASK_DANGEROUS)
task.parent_id = new_parent.id
self.update_children(task.id)
task.last_mod_time = time()
db.session.commit()
return task_id

def remove_task(self, task_id: int) -> List[int]:
"""
Removes a task, recursively removing its subtasks. Returns the list of
removed task IDs.
"""
removed = list()
children = db.session.query(Task).filter_by(parent_id=task_id).all()
for child in children:
removed.extend(self.remove_task(child.id))
task = self.get_task(task_id)
task.delete()
db.session.commit()
removed.append(task.id)
return removed
5 changes: 3 additions & 2 deletions project_amber/handlers/__init__.py
Expand Up @@ -16,11 +16,12 @@ class LoginUser:
and ID. The corresponding fields are `name` and `id`, respectively.
Also contains a token field.
"""
def __init__(self, name: str, uid: int, token: str, login_time: int):
def __init__(self, name: str, uid: int, token: str, login_time: int, remote_addr: str):
self.name = name
self.id = uid
self.token = token
self.login_time = login_time
self.remote_addr = remote_addr


def accepts_json(f):
Expand Down Expand Up @@ -54,7 +55,7 @@ def decorated_login_function(*args, **kwargs):
user = db.session.query(User).filter_by(id=user_s.user).one_or_none()
if user is None:
raise InternalServerError(MSG_USER_NOT_FOUND)
user_details = LoginUser(user.name, user.id, token, user_s.login_time)
user_details = LoginUser(user.name, user.id, token, user_s.login_time, request.remote_addr)
request.user = user_details
return f(*args, **kwargs)

Expand Down
15 changes: 10 additions & 5 deletions project_amber/handlers/auth.py
Expand Up @@ -6,7 +6,7 @@
from project_amber.errors import BadRequest
from project_amber.handlers import login_required, accepts_json
from project_amber.handlers.const import API_PASSWORD, API_USER, API_TOKEN
from project_amber.helpers.auth import removeSession, createSession
from project_amber.controllers.auth import UserController
from project_amber.logging import log

auth_handlers = Blueprint("auth_handlers", __name__)
Expand All @@ -31,9 +31,13 @@ def login():
```
Drops HTTP 401 on fail.
"""
if not API_USER in request.json or not API_PASSWORD in request.json:
username = request.json.get(API_USER)
password = request.json.get(API_PASSWORD)
if not username or not password:
raise BadRequest(MSG_MISSING_AUTH_INFO)
token = createSession(request.json[API_USER], request.json[API_PASSWORD])
uc = UserController(None)
token = uc.create_session(username, password, request.remote_addr)
log(f"User {username} logged in from {request.remote_addr}")
return dumps({API_TOKEN: token})


Expand All @@ -43,6 +47,7 @@ def logout():
"""
Logout handler. Returns HTTP 200 on success.
"""
removeSession(request.user.token)
log("User %s logged out" % request.user.name)
uc = UserController(request.user)
uc.remove_session()
log(f"User {request.user.name} logged out")
return EMPTY_RESP
33 changes: 11 additions & 22 deletions project_amber/handlers/session.py
Expand Up @@ -5,9 +5,8 @@
from project_amber.const import MATURE_SESSION, MSG_IMMATURE_SESSION, EMPTY_RESP
from project_amber.errors import Forbidden
from project_amber.handlers import login_required
from project_amber.handlers.const import API_ID, API_LOGIN_TIME, API_ADDRESS
from project_amber.helpers import time
from project_amber.helpers.auth import getSessions, getSession, removeSessionById
from project_amber.controllers.auth import UserController
from project_amber.logging import log

session_handlers = Blueprint("session_handlers", __name__)
Expand All @@ -34,14 +33,11 @@ def get_sessions():
]
```
"""
sessions = getSessions()
sessionList = []
uc = UserController(request.user)
sessions = uc.get_sessions()
sessionList = list()
for session in sessions:
sessionList.append({
API_ID: session.id,
API_LOGIN_TIME: session.login_time,
API_ADDRESS: session.address
})
sessionList.append(session.to_json())
return dumps(sessionList)


Expand All @@ -63,20 +59,13 @@ def session_by_id(session_id: int):
case here: if a client session is too recent, this will respond with
HTTP 403.
"""
uc = UserController(request.user)
if request.method == "GET":
session = getSession(session_id)
return dumps({
API_ID: session.id,
API_LOGIN_TIME: session.login_time,
API_ADDRESS: session.address
})
session = uc.get_session(session_id)
return dumps(session.to_json())
if request.method == "DELETE":
if (time() - request.user.login_time) < MATURE_SESSION:
if (time() - uc.user.login_time) < MATURE_SESSION:
raise Forbidden(MSG_IMMATURE_SESSION)
removeSessionById(session_id)
log(
"User {0} deleted session {1}".format(
request.user.name, session_id
)
)
uc.remove_session_by_id(session_id)
log(f"User {uc.user.name} deleted session {session_id}")
return EMPTY_RESP

0 comments on commit f1abdff

Please sign in to comment.