diff --git a/ref-test/Dockerfile b/ref-test/Dockerfile index dd67898..24850e9 100644 --- a/ref-test/Dockerfile +++ b/ref-test/Dockerfile @@ -2,4 +2,4 @@ FROM python:3.10-slim WORKDIR /ref-test COPY . . RUN pip install --upgrade pip && pip install -r requirements.txt -CMD [ "gunicorn", "-b", "0.0.0.0:5000", "-w", "8", "main:app" ] \ No newline at end of file +CMD [ "gunicorn", "-b", "0.0.0.0:5000", "-w", "5", "wsgi:app" ] \ No newline at end of file diff --git a/ref-test/admin/auth.py b/ref-test/admin/auth.py index 2e38c41..c614175 100644 --- a/ref-test/admin/auth.py +++ b/ref-test/admin/auth.py @@ -6,8 +6,6 @@ from uuid import uuid4 from common.security.database import decrypt_find_one, encrypted_update from werkzeug.security import check_password_hash -from main import db - from .views import admin_account_required, disable_on_registration, login_required, disable_if_logged_in, get_id_from_cookie auth = Blueprint( @@ -22,6 +20,7 @@ auth = Blueprint( @login_required def account(): from .models.forms import UpdateAccountForm + from main import db form = UpdateAccountForm() _id = get_id_from_cookie() user = decrypt_find_one(db.users, {'_id': _id}) @@ -112,6 +111,7 @@ def reset(): @admin_account_required @disable_if_logged_in def reset_gateway(token1,token2): + from main import db user = decrypt_find_one( db.users, {'reset_token' : token1} ) if not user: return redirect(url_for('admin_auth.login')) diff --git a/ref-test/admin/models/tests.py b/ref-test/admin/models/tests.py index edaaa70..123bc14 100644 --- a/ref-test/admin/models/tests.py +++ b/ref-test/admin/models/tests.py @@ -6,10 +6,10 @@ import secrets import os from json import dump, loads -from main import app, db from common.security import encrypt class Test: + def __init__(self, _id=None, start_date=None, expiry_date=None, time_limit=None, creator=None, dataset=None): self._id = _id self.start_date = start_date @@ -19,6 +19,7 @@ class Test: self.dataset = dataset def create(self): + from main import app, db test = { '_id': self._id, 'date_created': datetime.today(), @@ -41,6 +42,7 @@ class Test: return jsonify({'error': f'Could not create exam. An error occurred.'}), 400 def add_time_adjustment(self, time_adjustment): + from main import db user_code = secrets.token_hex(3).upper() adjustment = { user_code: time_adjustment @@ -51,6 +53,7 @@ class Test: return jsonify({'error': 'Failed to add the time adjustment. An error occurred.'}), 400 def remove_time_adjustment(self, user_code): + from main import db if db.tests.find_one_and_update({'_id': self._id}, {'$unset': {f'time_adjustments.{user_code}': {}}}): message = 'Time adjustment has been deleted.' flash(message, 'success') @@ -64,6 +67,7 @@ class Test: return test_code.replace('—', '') def delete(self): + from main import app, db test = db.tests.find_one({'_id': self._id}) if 'entries' in test: if test['entries']: @@ -83,6 +87,7 @@ class Test: return jsonify({'error': f'Could not create exam. An error occurred.'}), 400 def update(self): + from main import db test = {} updated = [] if not self.start_date == '' and self.start_date is not None: diff --git a/ref-test/admin/views.py b/ref-test/admin/views.py index 5ef7f96..78752a4 100644 --- a/ref-test/admin/views.py +++ b/ref-test/admin/views.py @@ -10,10 +10,8 @@ from werkzeug.security import check_password_hash from common.security.database import decrypt_find, decrypt_find_one from .models.users import User from flask_mail import Message -from main import app, db from uuid import uuid4 import secrets -from main import mail from datetime import datetime, date from .models.tests import Test from common.data_tools import get_default_dataset, get_time_options, available_datasets, get_datasets @@ -28,6 +26,8 @@ views = Blueprint( def admin_account_required(function): @wraps(function) def decorated_function(*args, **kwargs): + from main import db + from main import db if not db.users.find_one({}): flash('No administrator accounts have been registered. Please register an administrator account.', 'alert') return redirect(url_for('admin_auth.register')) @@ -37,6 +37,7 @@ def admin_account_required(function): def disable_on_registration(function): @wraps(function) def decorated_function(*args, **kwargs): + from main import db if db.users.find_one({}): return abort(404) return function(*args, **kwargs) @@ -46,6 +47,7 @@ def get_id_from_cookie(): return request.cookies.get('_id') def get_user_from_db(_id): + from main import db return db.users.find_one({'_id': _id}) def check_login(): @@ -76,6 +78,7 @@ def disable_if_logged_in(function): @admin_account_required @login_required def home(): + from main import db tests = db.tests.find() results = decrypt_find(db.entries, {}) current_tests = [ test for test in tests if test['expiry_date'].date() >= date.today() and test['start_date'].date() <= date.today() ] @@ -92,6 +95,7 @@ def home(): @admin_account_required @login_required def settings(): + from main import db users = decrypt_find(db.users, {}) users.sort(key= lambda x: x['username']) datasets = get_datasets() @@ -101,6 +105,7 @@ def settings(): @admin_account_required @login_required def users(): + from main import db, mail from .models.forms import CreateUserForm form = CreateUserForm() if request.method == 'GET': @@ -148,6 +153,7 @@ def users(): @admin_account_required @login_required def delete_user(_id:str): + from main import db, mail if _id == get_id_from_cookie(): flash('Cannot delete your own user account.', 'error') return redirect(url_for('admin_views.users')) @@ -197,6 +203,7 @@ def delete_user(_id:str): @admin_account_required @login_required def update_user(_id:str): + from main import db, mail if _id == get_id_from_cookie(): flash('Cannot delete your own user account.', 'error') return redirect(url_for('admin_views.users')) @@ -279,6 +286,7 @@ def questions(): @admin_account_required @login_required def delete_questions(): + from main import db, app filename = request.get_json()['filename'] data_files = glob(os.path.join(app.config["DATA_FILE_DIRECTORY"],'*.json')) if any(filename in file for file in data_files): @@ -301,6 +309,7 @@ def delete_questions(): @admin_account_required @login_required def make_default_questions(): + from main import app filename = request.get_json()['filename'] data_files = glob(os.path.join(app.config["DATA_FILE_DIRECTORY"],'*.json')) default_file_path = os.path.join(app.config['DATA_FILE_DIRECTORY'], '.default.txt') @@ -320,6 +329,7 @@ def make_default_questions(): @admin_account_required @login_required def tests(filter=''): + from main import db if not available_datasets(): flash('There are no available question datasets. Please upload a question dataset in order to set up an exam.', 'error') return redirect(url_for('admin_views.questions')) @@ -359,6 +369,7 @@ def tests(filter=''): @admin_account_required @login_required def create_test(): + from main import db from .models.forms import CreateTest form = CreateTest() form.dataset.choices = available_datasets() @@ -398,6 +409,7 @@ def create_test(): @admin_account_required @login_required def delete_test(): + from main import db _id = request.get_json()['_id'] if db.tests.find_one({'_id': _id}): return Test(_id = _id).delete() @@ -407,6 +419,7 @@ def delete_test(): @admin_account_required @login_required def view_test(_id): + from main import db from .models.forms import AddTimeAdjustment form = AddTimeAdjustment() test = decrypt_find_one(db.tests, {'_id': _id}) @@ -431,6 +444,7 @@ def delete_adjustment(_id): @admin_account_required @login_required def view_entries(): + from main import db entries = decrypt_find(db.entries, {}) return render_template('/admin/results.html', entries = entries) @@ -438,6 +452,7 @@ def view_entries(): @admin_account_required @login_required def view_entry(_id=''): + from main import db entry = decrypt_find_one(db.entries, {'_id': _id}) if request.method == 'GET': if not entry: @@ -468,6 +483,7 @@ def view_entry(_id=''): @admin_account_required @login_required def generate_certificate(): + from main import db _id = request.get_json()['_id'] entry = decrypt_find_one(db.entries, {'_id': _id}) if not entry: diff --git a/ref-test/common/data_tools.py b/ref-test/common/data_tools.py index 97916bd..7a42ffe 100644 --- a/ref-test/common/data_tools.py +++ b/ref-test/common/data_tools.py @@ -6,35 +6,38 @@ from glob import glob from random import shuffle from werkzeug.utils import secure_filename -from main import app, db from .security.database import decrypt_find_one def check_data_folder_exists(): + from main import app if not os.path.exists(app.config['DATA_FILE_DIRECTORY']): pathlib.Path(app.config['DATA_FILE_DIRECTORY']).mkdir(parents='True', exist_ok='True') def check_default_indicator(): + from main import app if not os.path.isfile(os.path.join(app.config['DATA_FILE_DIRECTORY'], '.default.txt')): open(os.path.join(app.config['DATA_FILE_DIRECTORY'], '.default.txt'),'w').close() def get_default_dataset(): check_default_indicator() + from main import app default_file_path = os.path.join(app.config['DATA_FILE_DIRECTORY'], '.default.txt') with open(default_file_path, 'r') as default_file: default = default_file.read() return default def available_datasets(): - files = glob(os.path.join(app.config["DATA_FILE_DIRECTORY"],'*.json')) - default = get_default_dataset() - output = [] - for file in files: - filename = file.rsplit('/')[-1] - label = f'{filename[:-5]} (Default)' if filename == default else filename[:-5] - element = (filename, label) - output.append(element) - output.reverse() - return output + from main import app + files = glob(os.path.join(app.config["DATA_FILE_DIRECTORY"],'*.json')) + default = get_default_dataset() + output = [] + for file in files: + filename = file.rsplit('/')[-1] + label = f'{filename[:-5]} (Default)' if filename == default else filename[:-5] + element = (filename, label) + output.append(element) + output.reverse() + return output def check_json_format(file): if not '.' in file.filename: @@ -58,6 +61,7 @@ def validate_json_contents(file): def store_data_file(file, default:bool=None): from admin.views import get_id_from_cookie + from main import app check_default_indicator() timestamp = datetime.utcnow() filename = '.'.join([timestamp.strftime('%Y%m%d%H%M%S'),'json']) @@ -201,6 +205,7 @@ def get_time_options(): return time_options def get_datasets(): + from main import app, db files = glob(os.path.join(app.config["DATA_FILE_DIRECTORY"],'*.json')) data = [] if files: diff --git a/ref-test/config.py b/ref-test/config.py index a1012c5..ad4cd32 100644 --- a/ref-test/config.py +++ b/ref-test/config.py @@ -5,8 +5,6 @@ class Config(object): TESTING = False SECRET_KEY = os.getenv('SECRET_KEY') - from dotenv import load_dotenv - load_dotenv() MONGO_INITDB_DATABASE = os.getenv('MONGO_INITDB_DATABASE') from urllib import parse MONGO_URI = f'mongodb://{os.getenv("MONGO_INITDB_USERNAME")}:{parse.quote_plus(os.getenv("MONGO_INITDB_PASSWORD"))}@{os.getenv("MONGO_DB_HOST_ALIAS")}:{os.getenv("MONGO_PORT")}/' @@ -32,21 +30,21 @@ class ProductionConfig(Config): pass class DevelopmentConfig(Config): - from dotenv import load_dotenv - load_dotenv() DEBUG = True SESSION_COOKIE_SECURE = False MONGO_INITDB_DATABASE = os.getenv('MONGO_INITDB_DATABASE') from urllib import parse MONGO_URI = f'mongodb://{os.getenv("MONGO_INITDB_USERNAME")}:{parse.quote_plus(os.getenv("MONGO_INITDB_PASSWORD"))}@localhost:{os.getenv("MONGO_PORT")}/' APP_HOST = '127.0.0.1' + MAIL_SERVER = 'localhost' MAIL_DEBUG = True MAIL_SUPPRESS_SEND = False -class TestingConfig(Config): - from dotenv import load_dotenv - load_dotenv() +class TestingConfig(DevelopmentConfig): TESTING = True SESSION_COOKIE_SECURE = False + MAIL_SERVER = os.getenv("MAIL_SERVER") MAIL_DEBUG = True - MAIL_SUPPRESS_SEND = False \ No newline at end of file + MAIL_SUPPRESS_SEND = False + from urllib import parse + MONGO_URI = f'mongodb://{os.getenv("MONGO_INITDB_USERNAME")}:{parse.quote_plus(os.getenv("MONGO_INITDB_PASSWORD"))}@{os.getenv("MONGO_DB_HOST_ALIAS")}:{os.getenv("MONGO_PORT")}/' \ No newline at end of file diff --git a/ref-test/main.py b/ref-test/main.py index bf1899a..9729e0a 100644 --- a/ref-test/main.py +++ b/ref-test/main.py @@ -10,32 +10,11 @@ from flask_wtf.csrf import CSRFProtect, CSRFError from flask_mail import Mail from common.security import check_keyfile_exists, generate_keyfile +import config -app = Flask(__name__) -app.config.from_object('config.DevelopmentConfig') - -Bootstrap(app) -csrf = CSRFProtect(app) - -@app.errorhandler(CSRFError) -def csrf_error_handler(error): - return jsonify({ 'error': 'Could not validate a secure connection.'} ), 400 - -try: - mongo = MongoClient(app.config['MONGO_URI']) - db = mongo[app.config['MONGO_INITDB_DATABASE']] -except ConnectionFailure as error: - print(error) - -try: - mail = Mail(app) -except Exception as error: - print(error) - -if __name__ == '__main__': - - if not check_keyfile_exists(): - generate_keyfile() +def create_app(): + app = Flask(__name__) + app.config.from_object(config.TestingConfig()) from common.blueprints import cookie_consent @@ -80,5 +59,23 @@ if __name__ == '__main__': @app.errorhandler(404) def _404_handler(e): return render_template('/quiz/404.html'), 404 + + @app.errorhandler(CSRFError) + def csrf_error_handler(error): + return jsonify({ 'error': 'Could not validate a secure connection.'} ), 400 + if not check_keyfile_exists(): + generate_keyfile() + + Bootstrap(app) + csrf = CSRFProtect(app) + + return app + +app = create_app() +mongo = MongoClient(app.config['MONGO_URI']) +db = mongo[app.config['MONGO_INITDB_DATABASE']] +mail = Mail(app) + +if __name__ == '__main__': app.run(host=app.config['APP_HOST']) \ No newline at end of file diff --git a/ref-test/quiz/views.py b/ref-test/quiz/views.py index 44c5464..fcc87ca 100644 --- a/ref-test/quiz/views.py +++ b/ref-test/quiz/views.py @@ -8,7 +8,6 @@ from flask_mail import Message from pymongo.collection import ReturnDocument -from main import app, db, mail from common.security import encrypt from common.data_tools import generate_questions, evaluate_answers from common.security.database import decrypt_find_one @@ -24,6 +23,7 @@ views = Blueprint( @views.route('/') @views.route('/home/') def home(): + from main import db _id = session.get('_id') if _id and db.entries.find_one({'_id': _id}): return redirect(url_for('quiz_views.start_quiz')) @@ -31,6 +31,7 @@ def home(): @views.route('/instructions/') def instructions(): + from main import db _id = session.get('_id') if _id and db.entries.find_one({'_id': _id}): return redirect(url_for('quiz_views.start_quiz')) @@ -38,6 +39,7 @@ def instructions(): @views.route('/start/', methods = ['GET', 'POST']) def start(): + from main import db from .forms import StartQuiz form = StartQuiz() if request.method == 'GET': @@ -85,6 +87,7 @@ def start(): @views.route('/api/questions/', methods=['POST']) def fetch_questions(): + from main import app, db _id = request.get_json()['_id'] entry = db.entries.find_one({'_id': _id}) if not entry: @@ -124,6 +127,7 @@ def fetch_questions(): @views.route('/test/') def start_quiz(): + from main import db _id = session.get('_id') if not _id or not db.entries.find_one({'_id': _id}): flash('Your log in was not recognised. Please sign in to the quiz again.', 'error') @@ -132,6 +136,7 @@ def start_quiz(): @views.route('/api/submit/', methods=['POST']) def submit_quiz(): + from main import app, db _id = request.get_json()['_id'] answers = request.get_json()['answers'] entry = db.entries.find_one({'_id': _id}) @@ -161,6 +166,7 @@ def submit_quiz(): @views.route('/result/') def result(): + from main import db, mail _id = session.get('_id') entry = decrypt_find_one(db.entries, {'_id': _id}) if not entry: diff --git a/ref-test/wsgi.py b/ref-test/wsgi.py new file mode 100644 index 0000000..b9a49b9 --- /dev/null +++ b/ref-test/wsgi.py @@ -0,0 +1,4 @@ +from main import app + +if __name__ == '__main__': + app.run() \ No newline at end of file