Dockerised. Restructured to remove circular import

Moved most of app definitions out of guard function to use wsgi
Updated configuration files and referencing of .env values.
Local version needs dotenv or exporting of env variables.
Dockerised version works fine without load_dotenv.
Ready to test now!
This commit is contained in:
Vivek Santayana 2021-12-05 03:49:31 +00:00
parent f9d16b3608
commit 281575bbf7
9 changed files with 82 additions and 51 deletions

View File

@ -2,4 +2,4 @@ FROM python:3.10-slim
WORKDIR /ref-test WORKDIR /ref-test
COPY . . COPY . .
RUN pip install --upgrade pip && pip install -r requirements.txt RUN pip install --upgrade pip && pip install -r requirements.txt
CMD [ "gunicorn", "-b", "0.0.0.0:5000", "-w", "8", "main:app" ] CMD [ "gunicorn", "-b", "0.0.0.0:5000", "-w", "5", "wsgi:app" ]

View File

@ -6,8 +6,6 @@ from uuid import uuid4
from common.security.database import decrypt_find_one, encrypted_update from common.security.database import decrypt_find_one, encrypted_update
from werkzeug.security import check_password_hash from werkzeug.security import check_password_hash
from main import db
from .views import admin_account_required, disable_on_registration, login_required, disable_if_logged_in, get_id_from_cookie from .views import admin_account_required, disable_on_registration, login_required, disable_if_logged_in, get_id_from_cookie
auth = Blueprint( auth = Blueprint(
@ -22,6 +20,7 @@ auth = Blueprint(
@login_required @login_required
def account(): def account():
from .models.forms import UpdateAccountForm from .models.forms import UpdateAccountForm
from main import db
form = UpdateAccountForm() form = UpdateAccountForm()
_id = get_id_from_cookie() _id = get_id_from_cookie()
user = decrypt_find_one(db.users, {'_id': _id}) user = decrypt_find_one(db.users, {'_id': _id})
@ -112,6 +111,7 @@ def reset():
@admin_account_required @admin_account_required
@disable_if_logged_in @disable_if_logged_in
def reset_gateway(token1,token2): def reset_gateway(token1,token2):
from main import db
user = decrypt_find_one( db.users, {'reset_token' : token1} ) user = decrypt_find_one( db.users, {'reset_token' : token1} )
if not user: if not user:
return redirect(url_for('admin_auth.login')) return redirect(url_for('admin_auth.login'))

View File

@ -6,10 +6,10 @@ import secrets
import os import os
from json import dump, loads from json import dump, loads
from main import app, db
from common.security import encrypt from common.security import encrypt
class Test: class Test:
def __init__(self, _id=None, start_date=None, expiry_date=None, time_limit=None, creator=None, dataset=None): def __init__(self, _id=None, start_date=None, expiry_date=None, time_limit=None, creator=None, dataset=None):
self._id = _id self._id = _id
self.start_date = start_date self.start_date = start_date
@ -19,6 +19,7 @@ class Test:
self.dataset = dataset self.dataset = dataset
def create(self): def create(self):
from main import app, db
test = { test = {
'_id': self._id, '_id': self._id,
'date_created': datetime.today(), 'date_created': datetime.today(),
@ -41,6 +42,7 @@ class Test:
return jsonify({'error': f'Could not create exam. An error occurred.'}), 400 return jsonify({'error': f'Could not create exam. An error occurred.'}), 400
def add_time_adjustment(self, time_adjustment): def add_time_adjustment(self, time_adjustment):
from main import db
user_code = secrets.token_hex(3).upper() user_code = secrets.token_hex(3).upper()
adjustment = { adjustment = {
user_code: time_adjustment user_code: time_adjustment
@ -51,6 +53,7 @@ class Test:
return jsonify({'error': 'Failed to add the time adjustment. An error occurred.'}), 400 return jsonify({'error': 'Failed to add the time adjustment. An error occurred.'}), 400
def remove_time_adjustment(self, user_code): def remove_time_adjustment(self, user_code):
from main import db
if db.tests.find_one_and_update({'_id': self._id}, {'$unset': {f'time_adjustments.{user_code}': {}}}): if db.tests.find_one_and_update({'_id': self._id}, {'$unset': {f'time_adjustments.{user_code}': {}}}):
message = 'Time adjustment has been deleted.' message = 'Time adjustment has been deleted.'
flash(message, 'success') flash(message, 'success')
@ -64,6 +67,7 @@ class Test:
return test_code.replace('', '') return test_code.replace('', '')
def delete(self): def delete(self):
from main import app, db
test = db.tests.find_one({'_id': self._id}) test = db.tests.find_one({'_id': self._id})
if 'entries' in test: if 'entries' in test:
if test['entries']: if test['entries']:
@ -83,6 +87,7 @@ class Test:
return jsonify({'error': f'Could not create exam. An error occurred.'}), 400 return jsonify({'error': f'Could not create exam. An error occurred.'}), 400
def update(self): def update(self):
from main import db
test = {} test = {}
updated = [] updated = []
if not self.start_date == '' and self.start_date is not None: if not self.start_date == '' and self.start_date is not None:

View File

@ -10,10 +10,8 @@ from werkzeug.security import check_password_hash
from common.security.database import decrypt_find, decrypt_find_one from common.security.database import decrypt_find, decrypt_find_one
from .models.users import User from .models.users import User
from flask_mail import Message from flask_mail import Message
from main import app, db
from uuid import uuid4 from uuid import uuid4
import secrets import secrets
from main import mail
from datetime import datetime, date from datetime import datetime, date
from .models.tests import Test from .models.tests import Test
from common.data_tools import get_default_dataset, get_time_options, available_datasets, get_datasets from common.data_tools import get_default_dataset, get_time_options, available_datasets, get_datasets
@ -28,6 +26,8 @@ views = Blueprint(
def admin_account_required(function): def admin_account_required(function):
@wraps(function) @wraps(function)
def decorated_function(*args, **kwargs): def decorated_function(*args, **kwargs):
from main import db
from main import db
if not db.users.find_one({}): if not db.users.find_one({}):
flash('No administrator accounts have been registered. Please register an administrator account.', 'alert') flash('No administrator accounts have been registered. Please register an administrator account.', 'alert')
return redirect(url_for('admin_auth.register')) return redirect(url_for('admin_auth.register'))
@ -37,6 +37,7 @@ def admin_account_required(function):
def disable_on_registration(function): def disable_on_registration(function):
@wraps(function) @wraps(function)
def decorated_function(*args, **kwargs): def decorated_function(*args, **kwargs):
from main import db
if db.users.find_one({}): if db.users.find_one({}):
return abort(404) return abort(404)
return function(*args, **kwargs) return function(*args, **kwargs)
@ -46,6 +47,7 @@ def get_id_from_cookie():
return request.cookies.get('_id') return request.cookies.get('_id')
def get_user_from_db(_id): def get_user_from_db(_id):
from main import db
return db.users.find_one({'_id': _id}) return db.users.find_one({'_id': _id})
def check_login(): def check_login():
@ -76,6 +78,7 @@ def disable_if_logged_in(function):
@admin_account_required @admin_account_required
@login_required @login_required
def home(): def home():
from main import db
tests = db.tests.find() tests = db.tests.find()
results = decrypt_find(db.entries, {}) results = decrypt_find(db.entries, {})
current_tests = [ test for test in tests if test['expiry_date'].date() >= date.today() and test['start_date'].date() <= date.today() ] current_tests = [ test for test in tests if test['expiry_date'].date() >= date.today() and test['start_date'].date() <= date.today() ]
@ -92,6 +95,7 @@ def home():
@admin_account_required @admin_account_required
@login_required @login_required
def settings(): def settings():
from main import db
users = decrypt_find(db.users, {}) users = decrypt_find(db.users, {})
users.sort(key= lambda x: x['username']) users.sort(key= lambda x: x['username'])
datasets = get_datasets() datasets = get_datasets()
@ -101,6 +105,7 @@ def settings():
@admin_account_required @admin_account_required
@login_required @login_required
def users(): def users():
from main import db, mail
from .models.forms import CreateUserForm from .models.forms import CreateUserForm
form = CreateUserForm() form = CreateUserForm()
if request.method == 'GET': if request.method == 'GET':
@ -148,6 +153,7 @@ def users():
@admin_account_required @admin_account_required
@login_required @login_required
def delete_user(_id:str): def delete_user(_id:str):
from main import db, mail
if _id == get_id_from_cookie(): if _id == get_id_from_cookie():
flash('Cannot delete your own user account.', 'error') flash('Cannot delete your own user account.', 'error')
return redirect(url_for('admin_views.users')) return redirect(url_for('admin_views.users'))
@ -197,6 +203,7 @@ def delete_user(_id:str):
@admin_account_required @admin_account_required
@login_required @login_required
def update_user(_id:str): def update_user(_id:str):
from main import db, mail
if _id == get_id_from_cookie(): if _id == get_id_from_cookie():
flash('Cannot delete your own user account.', 'error') flash('Cannot delete your own user account.', 'error')
return redirect(url_for('admin_views.users')) return redirect(url_for('admin_views.users'))
@ -279,6 +286,7 @@ def questions():
@admin_account_required @admin_account_required
@login_required @login_required
def delete_questions(): def delete_questions():
from main import db, app
filename = request.get_json()['filename'] filename = request.get_json()['filename']
data_files = glob(os.path.join(app.config["DATA_FILE_DIRECTORY"],'*.json')) data_files = glob(os.path.join(app.config["DATA_FILE_DIRECTORY"],'*.json'))
if any(filename in file for file in data_files): if any(filename in file for file in data_files):
@ -301,6 +309,7 @@ def delete_questions():
@admin_account_required @admin_account_required
@login_required @login_required
def make_default_questions(): def make_default_questions():
from main import app
filename = request.get_json()['filename'] filename = request.get_json()['filename']
data_files = glob(os.path.join(app.config["DATA_FILE_DIRECTORY"],'*.json')) data_files = glob(os.path.join(app.config["DATA_FILE_DIRECTORY"],'*.json'))
default_file_path = os.path.join(app.config['DATA_FILE_DIRECTORY'], '.default.txt') default_file_path = os.path.join(app.config['DATA_FILE_DIRECTORY'], '.default.txt')
@ -320,6 +329,7 @@ def make_default_questions():
@admin_account_required @admin_account_required
@login_required @login_required
def tests(filter=''): def tests(filter=''):
from main import db
if not available_datasets(): if not available_datasets():
flash('There are no available question datasets. Please upload a question dataset in order to set up an exam.', 'error') flash('There are no available question datasets. Please upload a question dataset in order to set up an exam.', 'error')
return redirect(url_for('admin_views.questions')) return redirect(url_for('admin_views.questions'))
@ -359,6 +369,7 @@ def tests(filter=''):
@admin_account_required @admin_account_required
@login_required @login_required
def create_test(): def create_test():
from main import db
from .models.forms import CreateTest from .models.forms import CreateTest
form = CreateTest() form = CreateTest()
form.dataset.choices = available_datasets() form.dataset.choices = available_datasets()
@ -398,6 +409,7 @@ def create_test():
@admin_account_required @admin_account_required
@login_required @login_required
def delete_test(): def delete_test():
from main import db
_id = request.get_json()['_id'] _id = request.get_json()['_id']
if db.tests.find_one({'_id': _id}): if db.tests.find_one({'_id': _id}):
return Test(_id = _id).delete() return Test(_id = _id).delete()
@ -407,6 +419,7 @@ def delete_test():
@admin_account_required @admin_account_required
@login_required @login_required
def view_test(_id): def view_test(_id):
from main import db
from .models.forms import AddTimeAdjustment from .models.forms import AddTimeAdjustment
form = AddTimeAdjustment() form = AddTimeAdjustment()
test = decrypt_find_one(db.tests, {'_id': _id}) test = decrypt_find_one(db.tests, {'_id': _id})
@ -431,6 +444,7 @@ def delete_adjustment(_id):
@admin_account_required @admin_account_required
@login_required @login_required
def view_entries(): def view_entries():
from main import db
entries = decrypt_find(db.entries, {}) entries = decrypt_find(db.entries, {})
return render_template('/admin/results.html', entries = entries) return render_template('/admin/results.html', entries = entries)
@ -438,6 +452,7 @@ def view_entries():
@admin_account_required @admin_account_required
@login_required @login_required
def view_entry(_id=''): def view_entry(_id=''):
from main import db
entry = decrypt_find_one(db.entries, {'_id': _id}) entry = decrypt_find_one(db.entries, {'_id': _id})
if request.method == 'GET': if request.method == 'GET':
if not entry: if not entry:
@ -468,6 +483,7 @@ def view_entry(_id=''):
@admin_account_required @admin_account_required
@login_required @login_required
def generate_certificate(): def generate_certificate():
from main import db
_id = request.get_json()['_id'] _id = request.get_json()['_id']
entry = decrypt_find_one(db.entries, {'_id': _id}) entry = decrypt_find_one(db.entries, {'_id': _id})
if not entry: if not entry:

View File

@ -6,35 +6,38 @@ from glob import glob
from random import shuffle from random import shuffle
from werkzeug.utils import secure_filename from werkzeug.utils import secure_filename
from main import app, db
from .security.database import decrypt_find_one from .security.database import decrypt_find_one
def check_data_folder_exists(): def check_data_folder_exists():
from main import app
if not os.path.exists(app.config['DATA_FILE_DIRECTORY']): if not os.path.exists(app.config['DATA_FILE_DIRECTORY']):
pathlib.Path(app.config['DATA_FILE_DIRECTORY']).mkdir(parents='True', exist_ok='True') pathlib.Path(app.config['DATA_FILE_DIRECTORY']).mkdir(parents='True', exist_ok='True')
def check_default_indicator(): def check_default_indicator():
from main import app
if not os.path.isfile(os.path.join(app.config['DATA_FILE_DIRECTORY'], '.default.txt')): if not os.path.isfile(os.path.join(app.config['DATA_FILE_DIRECTORY'], '.default.txt')):
open(os.path.join(app.config['DATA_FILE_DIRECTORY'], '.default.txt'),'w').close() open(os.path.join(app.config['DATA_FILE_DIRECTORY'], '.default.txt'),'w').close()
def get_default_dataset(): def get_default_dataset():
check_default_indicator() check_default_indicator()
from main import app
default_file_path = os.path.join(app.config['DATA_FILE_DIRECTORY'], '.default.txt') default_file_path = os.path.join(app.config['DATA_FILE_DIRECTORY'], '.default.txt')
with open(default_file_path, 'r') as default_file: with open(default_file_path, 'r') as default_file:
default = default_file.read() default = default_file.read()
return default return default
def available_datasets(): def available_datasets():
files = glob(os.path.join(app.config["DATA_FILE_DIRECTORY"],'*.json')) from main import app
default = get_default_dataset() files = glob(os.path.join(app.config["DATA_FILE_DIRECTORY"],'*.json'))
output = [] default = get_default_dataset()
for file in files: output = []
filename = file.rsplit('/')[-1] for file in files:
label = f'{filename[:-5]} (Default)' if filename == default else filename[:-5] filename = file.rsplit('/')[-1]
element = (filename, label) label = f'{filename[:-5]} (Default)' if filename == default else filename[:-5]
output.append(element) element = (filename, label)
output.reverse() output.append(element)
return output output.reverse()
return output
def check_json_format(file): def check_json_format(file):
if not '.' in file.filename: if not '.' in file.filename:
@ -58,6 +61,7 @@ def validate_json_contents(file):
def store_data_file(file, default:bool=None): def store_data_file(file, default:bool=None):
from admin.views import get_id_from_cookie from admin.views import get_id_from_cookie
from main import app
check_default_indicator() check_default_indicator()
timestamp = datetime.utcnow() timestamp = datetime.utcnow()
filename = '.'.join([timestamp.strftime('%Y%m%d%H%M%S'),'json']) filename = '.'.join([timestamp.strftime('%Y%m%d%H%M%S'),'json'])
@ -201,6 +205,7 @@ def get_time_options():
return time_options return time_options
def get_datasets(): def get_datasets():
from main import app, db
files = glob(os.path.join(app.config["DATA_FILE_DIRECTORY"],'*.json')) files = glob(os.path.join(app.config["DATA_FILE_DIRECTORY"],'*.json'))
data = [] data = []
if files: if files:

View File

@ -5,8 +5,6 @@ class Config(object):
TESTING = False TESTING = False
SECRET_KEY = os.getenv('SECRET_KEY') SECRET_KEY = os.getenv('SECRET_KEY')
from dotenv import load_dotenv
load_dotenv()
MONGO_INITDB_DATABASE = os.getenv('MONGO_INITDB_DATABASE') MONGO_INITDB_DATABASE = os.getenv('MONGO_INITDB_DATABASE')
from urllib import parse from urllib import parse
MONGO_URI = f'mongodb://{os.getenv("MONGO_INITDB_USERNAME")}:{parse.quote_plus(os.getenv("MONGO_INITDB_PASSWORD"))}@{os.getenv("MONGO_DB_HOST_ALIAS")}:{os.getenv("MONGO_PORT")}/' MONGO_URI = f'mongodb://{os.getenv("MONGO_INITDB_USERNAME")}:{parse.quote_plus(os.getenv("MONGO_INITDB_PASSWORD"))}@{os.getenv("MONGO_DB_HOST_ALIAS")}:{os.getenv("MONGO_PORT")}/'
@ -32,21 +30,21 @@ class ProductionConfig(Config):
pass pass
class DevelopmentConfig(Config): class DevelopmentConfig(Config):
from dotenv import load_dotenv
load_dotenv()
DEBUG = True DEBUG = True
SESSION_COOKIE_SECURE = False SESSION_COOKIE_SECURE = False
MONGO_INITDB_DATABASE = os.getenv('MONGO_INITDB_DATABASE') MONGO_INITDB_DATABASE = os.getenv('MONGO_INITDB_DATABASE')
from urllib import parse from urllib import parse
MONGO_URI = f'mongodb://{os.getenv("MONGO_INITDB_USERNAME")}:{parse.quote_plus(os.getenv("MONGO_INITDB_PASSWORD"))}@localhost:{os.getenv("MONGO_PORT")}/' MONGO_URI = f'mongodb://{os.getenv("MONGO_INITDB_USERNAME")}:{parse.quote_plus(os.getenv("MONGO_INITDB_PASSWORD"))}@localhost:{os.getenv("MONGO_PORT")}/'
APP_HOST = '127.0.0.1' APP_HOST = '127.0.0.1'
MAIL_SERVER = 'localhost'
MAIL_DEBUG = True MAIL_DEBUG = True
MAIL_SUPPRESS_SEND = False MAIL_SUPPRESS_SEND = False
class TestingConfig(Config): class TestingConfig(DevelopmentConfig):
from dotenv import load_dotenv
load_dotenv()
TESTING = True TESTING = True
SESSION_COOKIE_SECURE = False SESSION_COOKIE_SECURE = False
MAIL_SERVER = os.getenv("MAIL_SERVER")
MAIL_DEBUG = True MAIL_DEBUG = True
MAIL_SUPPRESS_SEND = False MAIL_SUPPRESS_SEND = False
from urllib import parse
MONGO_URI = f'mongodb://{os.getenv("MONGO_INITDB_USERNAME")}:{parse.quote_plus(os.getenv("MONGO_INITDB_PASSWORD"))}@{os.getenv("MONGO_DB_HOST_ALIAS")}:{os.getenv("MONGO_PORT")}/'

View File

@ -10,32 +10,11 @@ from flask_wtf.csrf import CSRFProtect, CSRFError
from flask_mail import Mail from flask_mail import Mail
from common.security import check_keyfile_exists, generate_keyfile from common.security import check_keyfile_exists, generate_keyfile
import config
app = Flask(__name__) def create_app():
app.config.from_object('config.DevelopmentConfig') app = Flask(__name__)
app.config.from_object(config.TestingConfig())
Bootstrap(app)
csrf = CSRFProtect(app)
@app.errorhandler(CSRFError)
def csrf_error_handler(error):
return jsonify({ 'error': 'Could not validate a secure connection.'} ), 400
try:
mongo = MongoClient(app.config['MONGO_URI'])
db = mongo[app.config['MONGO_INITDB_DATABASE']]
except ConnectionFailure as error:
print(error)
try:
mail = Mail(app)
except Exception as error:
print(error)
if __name__ == '__main__':
if not check_keyfile_exists():
generate_keyfile()
from common.blueprints import cookie_consent from common.blueprints import cookie_consent
@ -81,4 +60,22 @@ if __name__ == '__main__':
def _404_handler(e): def _404_handler(e):
return render_template('/quiz/404.html'), 404 return render_template('/quiz/404.html'), 404
@app.errorhandler(CSRFError)
def csrf_error_handler(error):
return jsonify({ 'error': 'Could not validate a secure connection.'} ), 400
if not check_keyfile_exists():
generate_keyfile()
Bootstrap(app)
csrf = CSRFProtect(app)
return app
app = create_app()
mongo = MongoClient(app.config['MONGO_URI'])
db = mongo[app.config['MONGO_INITDB_DATABASE']]
mail = Mail(app)
if __name__ == '__main__':
app.run(host=app.config['APP_HOST']) app.run(host=app.config['APP_HOST'])

View File

@ -8,7 +8,6 @@ from flask_mail import Message
from pymongo.collection import ReturnDocument from pymongo.collection import ReturnDocument
from main import app, db, mail
from common.security import encrypt from common.security import encrypt
from common.data_tools import generate_questions, evaluate_answers from common.data_tools import generate_questions, evaluate_answers
from common.security.database import decrypt_find_one from common.security.database import decrypt_find_one
@ -24,6 +23,7 @@ views = Blueprint(
@views.route('/') @views.route('/')
@views.route('/home/') @views.route('/home/')
def home(): def home():
from main import db
_id = session.get('_id') _id = session.get('_id')
if _id and db.entries.find_one({'_id': _id}): if _id and db.entries.find_one({'_id': _id}):
return redirect(url_for('quiz_views.start_quiz')) return redirect(url_for('quiz_views.start_quiz'))
@ -31,6 +31,7 @@ def home():
@views.route('/instructions/') @views.route('/instructions/')
def instructions(): def instructions():
from main import db
_id = session.get('_id') _id = session.get('_id')
if _id and db.entries.find_one({'_id': _id}): if _id and db.entries.find_one({'_id': _id}):
return redirect(url_for('quiz_views.start_quiz')) return redirect(url_for('quiz_views.start_quiz'))
@ -38,6 +39,7 @@ def instructions():
@views.route('/start/', methods = ['GET', 'POST']) @views.route('/start/', methods = ['GET', 'POST'])
def start(): def start():
from main import db
from .forms import StartQuiz from .forms import StartQuiz
form = StartQuiz() form = StartQuiz()
if request.method == 'GET': if request.method == 'GET':
@ -85,6 +87,7 @@ def start():
@views.route('/api/questions/', methods=['POST']) @views.route('/api/questions/', methods=['POST'])
def fetch_questions(): def fetch_questions():
from main import app, db
_id = request.get_json()['_id'] _id = request.get_json()['_id']
entry = db.entries.find_one({'_id': _id}) entry = db.entries.find_one({'_id': _id})
if not entry: if not entry:
@ -124,6 +127,7 @@ def fetch_questions():
@views.route('/test/') @views.route('/test/')
def start_quiz(): def start_quiz():
from main import db
_id = session.get('_id') _id = session.get('_id')
if not _id or not db.entries.find_one({'_id': _id}): if not _id or not db.entries.find_one({'_id': _id}):
flash('Your log in was not recognised. Please sign in to the quiz again.', 'error') flash('Your log in was not recognised. Please sign in to the quiz again.', 'error')
@ -132,6 +136,7 @@ def start_quiz():
@views.route('/api/submit/', methods=['POST']) @views.route('/api/submit/', methods=['POST'])
def submit_quiz(): def submit_quiz():
from main import app, db
_id = request.get_json()['_id'] _id = request.get_json()['_id']
answers = request.get_json()['answers'] answers = request.get_json()['answers']
entry = db.entries.find_one({'_id': _id}) entry = db.entries.find_one({'_id': _id})
@ -161,6 +166,7 @@ def submit_quiz():
@views.route('/result/') @views.route('/result/')
def result(): def result():
from main import db, mail
_id = session.get('_id') _id = session.get('_id')
entry = decrypt_find_one(db.entries, {'_id': _id}) entry = decrypt_find_one(db.entries, {'_id': _id})
if not entry: if not entry:

4
ref-test/wsgi.py Normal file
View File

@ -0,0 +1,4 @@
from main import app
if __name__ == '__main__':
app.run()