This repository has been archived by the owner on Jan 20, 2024. It is now read-only.
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Browse files
Browse the repository at this point in the history
Do a major refactoring of DB code (see #7)
- Loading branch information
Showing
11 changed files
with
309 additions
and
305 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,148 @@ | ||
from hashlib import sha256 | ||
from base64 import b64encode | ||
from typing import List | ||
|
||
from bcrypt import hashpw, gensalt, checkpw | ||
|
||
from project_amber.const import MSG_USER_EXISTS | ||
from project_amber.db import db | ||
from project_amber.helpers import time | ||
from project_amber.handlers import LoginUser | ||
from project_amber.handlers.const import API_PASSWORD | ||
from project_amber.errors import Unauthorized, NotFound, Conflict | ||
from project_amber.logging import error | ||
from project_amber.models.auth import User, Session | ||
|
||
|
||
def prehash(password: str) -> bytes: | ||
""" | ||
Returns a "normalized" representation of the password that works | ||
with bcrypt even when the password is longer than 72 chars. | ||
""" | ||
return b64encode(sha256(password.encode()).digest()) | ||
|
||
|
||
def gen_hashed_pw(password: str) -> bytes: | ||
""" | ||
Returns a bcrypt password hash with random salt. | ||
""" | ||
return hashpw(prehash(password), gensalt()).decode() | ||
|
||
|
||
def gen_token() -> str: | ||
""" | ||
Returns a new freshly generated auth token. | ||
""" | ||
return sha256(gensalt() + bytes(str(time()).encode())).hexdigest() | ||
|
||
|
||
class UserController: | ||
user: LoginUser = None | ||
|
||
def __init__(self, user: LoginUser): | ||
self.user = user | ||
|
||
def add_user(self, name: str, password: str) -> int: | ||
""" | ||
Creates a new user. Returns their ID on success. | ||
""" | ||
# does a user with this name already exist? | ||
if not db.session.query(User).filter_by(name=name).one_or_none() is None: | ||
raise Conflict(MSG_USER_EXISTS) | ||
hashed_pw = gen_hashed_pw(password) | ||
user = User(name=name, password=hashed_pw) | ||
db.session.add(user) | ||
db.session.commit() | ||
return user.id | ||
|
||
def update_user(self, **kwargs) -> int: | ||
""" | ||
Updates user data in the database. Returns their ID on success. | ||
""" | ||
user_record = db.session.query(User).filter_by(id=self.user.id).one() | ||
for attribute in kwargs: | ||
if attribute == API_PASSWORD: | ||
user_record.password = gen_hashed_pw(kwargs[API_PASSWORD]) | ||
db.session.commit() | ||
return self.user.id | ||
|
||
def remove_user(self) -> int: | ||
""" | ||
Removes a user from the database. Returns their ID. | ||
""" | ||
user = db.session.query(User).filter_by(id=self.user.id).one_or_none() | ||
try: | ||
db.session.delete(user) | ||
db.session.commit() | ||
# pylint: disable=bare-except | ||
except: | ||
error("Failed to remove user %s!" % user.name) | ||
return self.user.id | ||
|
||
def verify_pw(self, uid: int, password: str) -> bool: | ||
""" | ||
Verifies user's password with bcrypt's checkpw(). Returns `True`, if | ||
the passwords match, and False otherwise. | ||
""" | ||
user = db.session.query(User).filter_by(id=uid).one() | ||
user_pass = user.password | ||
if isinstance(user_pass, str): | ||
user_pass = user_pass.encode() | ||
return checkpw(prehash(password), user_pass) | ||
|
||
def create_session(self, name: str, password: str, ip_addr: str) -> str: | ||
""" | ||
Creates a new user session. Returns an auth token. | ||
""" | ||
user = db.session.query(User).filter_by(name=name).one_or_none() | ||
token: str | ||
if user is None: | ||
raise Unauthorized | ||
if self.verify_pw(user.id, password): | ||
token = gen_token() | ||
session = Session(token=token, user=user.id, login_time=time(), address=ip_addr) | ||
db.session.add(session) | ||
db.session.commit() | ||
else: | ||
raise Unauthorized | ||
return token | ||
|
||
def remove_session(self) -> str: | ||
""" | ||
Logs the user out by removing their token from the database. Returns | ||
the token on success. | ||
""" | ||
session = db.session.query(Session).filter_by(token=self.user.token, | ||
user=self.user.id).one_or_none() | ||
if session is None: | ||
raise NotFound | ||
db.session.delete(session) | ||
db.session.commit() | ||
return self.user.token | ||
|
||
def remove_session_by_id(self, sid: int) -> int: | ||
""" | ||
Removes a user session by session ID. Returns the session ID on success. | ||
""" | ||
session = db.session.query(Session).filter_by(id=sid, user=self.user.id).one_or_none() | ||
if session is None: | ||
raise NotFound | ||
db.session.delete(session) | ||
db.session.commit() | ||
return sid | ||
|
||
def get_sessions(self) -> List[Session]: | ||
""" | ||
Returns a list of sessions of a user (class `Session`). | ||
""" | ||
sessions = db.session.query(Session).filter_by(user=self.user.id).all() | ||
return sessions | ||
|
||
def get_session(self, sid: int) -> Session: | ||
""" | ||
Returns a single `Session` by its ID. | ||
""" | ||
session = db.session.query(Session).filter_by(id=sid, user=self.user.id).one_or_none() | ||
if session is None: | ||
raise NotFound | ||
return session |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
from typing import List | ||
|
||
from project_amber.const import MSG_TASK_NOT_FOUND, MSG_TASK_DANGEROUS, \ | ||
MSG_TEXT_NOT_SPECIFIED | ||
from project_amber.db import db | ||
from project_amber.errors import NotFound, BadRequest | ||
from project_amber.handlers import LoginUser | ||
from project_amber.helpers import time | ||
from project_amber.models.task import Task | ||
|
||
|
||
class TaskController: | ||
user: LoginUser = None | ||
|
||
def __init__(self, user: LoginUser): | ||
self.user = user | ||
|
||
def add_task(self, data: dict) -> int: | ||
""" | ||
Creates a new task. Returns its ID. | ||
""" | ||
task = Task(self.user.id, data) | ||
if task.text is None: raise BadRequest(MSG_TEXT_NOT_SPECIFIED) | ||
if task.status is None: task.status = 0 | ||
parent_id = task.parent_id | ||
if parent_id: | ||
parent = db.session.query(Task).filter_by(id=parent_id, | ||
owner=self.user.id).one_or_none() | ||
if parent is None: | ||
raise NotFound(MSG_TASK_NOT_FOUND) | ||
task.add() | ||
db.session.commit() | ||
self.update_children(task.id) | ||
# TODO: can we remove the second commit here? | ||
db.session.commit() | ||
return task.id | ||
|
||
def get_task(self, task_id: int) -> Task: | ||
""" | ||
Returns an instance of `Task`, given the ID. | ||
""" | ||
task = db.session.query(Task).filter_by(id=task_id, owner=self.user.id).one_or_none() | ||
if task is None: | ||
raise NotFound(MSG_TASK_NOT_FOUND) | ||
return task | ||
|
||
def get_tasks(self, text: str = None) -> List[Task]: | ||
""" | ||
Returns a list containing tasks from a certain user. If the second | ||
parameter is specified, this will return the tasks that have this text | ||
in their description (`text in Task.text`). | ||
""" | ||
req = db.session.query(Task).filter_by(owner=self.user.id) | ||
if text is None: | ||
return req.all() | ||
return req.filter(Task.text.ilike("%{0}%".format(text))).all() | ||
|
||
def update_children(self, task_id: int): | ||
""" | ||
Recursively updates children lists for the children nodes of | ||
a task subtree. | ||
""" | ||
task = self.get_task(task_id) | ||
if task.parent_id: | ||
parent = self.get_task(task.parent_id) | ||
parent_list = parent.getParents() | ||
parent_list.append(parent.id) | ||
task.setParents(parent_list) | ||
else: | ||
task.setParents(list()) | ||
children = db.session.query(Task).filter_by(parent_id=task_id).all() | ||
for child in children: | ||
self.update_children(child.id) | ||
|
||
def update_task(self, task_id: int, data: dict) -> int: | ||
""" | ||
Updates the task details. Returns its ID. | ||
""" | ||
task = self.get_task(task_id) | ||
new_details = Task(self.user.id, data) | ||
task.merge(new_details) | ||
if not new_details.parent_id is None: | ||
if new_details.parent_id == 0: | ||
# promote task to the top level | ||
task.parent_id = None | ||
self.update_children(task.id) | ||
else: | ||
new_parent = self.get_task(new_details.parent_id) | ||
if task.id in new_parent.getParents() or task.id == new_parent.id: | ||
raise BadRequest(MSG_TASK_DANGEROUS) | ||
task.parent_id = new_parent.id | ||
self.update_children(task.id) | ||
task.last_mod_time = time() | ||
db.session.commit() | ||
return task_id | ||
|
||
def remove_task(self, task_id: int) -> List[int]: | ||
""" | ||
Removes a task, recursively removing its subtasks. Returns the list of | ||
removed task IDs. | ||
""" | ||
removed = list() | ||
children = db.session.query(Task).filter_by(parent_id=task_id).all() | ||
for child in children: | ||
removed.extend(self.remove_task(child.id)) | ||
task = self.get_task(task_id) | ||
task.delete() | ||
db.session.commit() | ||
removed.append(task.id) | ||
return removed |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.