Applied auto pep 8 changes

This commit is contained in:
Luke Else 2024-01-21 22:06:06 +00:00
parent f227727c74
commit 44c1ee03ba
22 changed files with 156 additions and 100 deletions

7
app.py
View File

@ -7,6 +7,8 @@ from controllers.web.endpoints import blueprint
Initialises any components that are needed at runtime such as the Initialises any components that are needed at runtime such as the
Database manager... Database manager...
''' '''
def main(): def main():
app = Flask(__name__) app = Flask(__name__)
@ -19,9 +21,10 @@ def main():
else: else:
app.secret_key = secret_key app.secret_key = secret_key
# Register a blueprint # Register a blueprint
app.register_blueprint(blueprint) app.register_blueprint(blueprint)
app.run(debug=True, host="0.0.0.0") app.run(debug=True, host="0.0.0.0")
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@ -1,6 +1,7 @@
from .database import DatabaseController from .database import DatabaseController
from models.category import Category from models.category import Category
class CategoryController(DatabaseController): class CategoryController(DatabaseController):
FIELDS = ['id', 'name'] FIELDS = ['id', 'name']
@ -18,7 +19,6 @@ class CategoryController(DatabaseController):
) )
self._conn.commit() self._conn.commit()
def read(self, id: int = 0) -> Category | None: def read(self, id: int = 0) -> Category | None:
params = [ params = [
id id
@ -32,12 +32,11 @@ class CategoryController(DatabaseController):
if row == None: if row == None:
return None return None
params = dict(zip(self.FIELDS, row)) params = dict(zip(self.FIELDS, row))
obj = self.new_instance(Category, params) obj = self.new_instance(Category, params)
return obj
return obj
def read_all(self) -> list[Category] | None: def read_all(self) -> list[Category] | None:
cursor = self._conn.execute( cursor = self._conn.execute(
@ -47,18 +46,18 @@ class CategoryController(DatabaseController):
if rows == None: if rows == None:
return None return None
categories = list() categories = list()
for category in rows: for category in rows:
params = dict(zip(self.FIELDS, category)) params = dict(zip(self.FIELDS, category))
obj = self.new_instance(Category, params) obj = self.new_instance(Category, params)
categories.append(obj) categories.append(obj)
return categories return categories
def update(self): def update(self):
print("Doing work") print("Doing work")
def delete(self): def delete(self):
print("Doing work") print("Doing work")

View File

@ -1,8 +1,9 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Mapping, Any from typing import Mapping, Any
import sqlite3 import sqlite3
import os import os
class DatabaseController(ABC): class DatabaseController(ABC):
__data_dir = "./data/" __data_dir = "./data/"
__db_name = "wmgzon.db" __db_name = "wmgzon.db"
@ -13,13 +14,13 @@ class DatabaseController(ABC):
__sqlitefile = __data_dir + __db_name __sqlitefile = __data_dir + __db_name
def __init__(self): def __init__(self):
self._conn = None self._conn = None
try: try:
# Creates a connection and specifies a flag to parse all types back down into # Creates a connection and specifies a flag to parse all types back down into
# Python declared types e.g. date & time # Python declared types e.g. date & time
self._conn = sqlite3.connect(self.__sqlitefile, detect_types=sqlite3.PARSE_DECLTYPES) self._conn = sqlite3.connect(
self.__sqlitefile, detect_types=sqlite3.PARSE_DECLTYPES)
except sqlite3.Error as e: except sqlite3.Error as e:
# Close the connection if still open # Close the connection if still open
if self._conn: if self._conn:
@ -27,17 +28,18 @@ class DatabaseController(ABC):
print(e) print(e)
def __del__(self): def __del__(self):
if self._conn != None: if self._conn != None:
self._conn.close() self._conn.close()
""" Takes a dictionary of fields and returns the object """ Takes a dictionary of fields and returns the object
with those fields populated """ with those fields populated """
def new_instance(self, of: type, with_fields: Mapping[str, Any]): def new_instance(self, of: type, with_fields: Mapping[str, Any]):
obj = of.__new__(of) obj = of.__new__(of)
for attr, value in with_fields.items(): for attr, value in with_fields.items():
setattr(obj, attr, value) setattr(obj, attr, value)
return obj return obj
""" """
Set of CRUD methods to allow for Data manipulation on the backend Set of CRUD methods to allow for Data manipulation on the backend
""" """
@ -53,7 +55,7 @@ class DatabaseController(ABC):
@abstractmethod @abstractmethod
def update(self): def update(self):
pass pass
@abstractmethod @abstractmethod
def delete(self): def delete(self):
pass pass

View File

@ -1,8 +1,10 @@
from .database import DatabaseController from .database import DatabaseController
from models.products.product import Product from models.products.product import Product
class ProductController(DatabaseController): class ProductController(DatabaseController):
FIELDS = ['id', 'name', 'image', 'description', 'cost', 'category', 'sellerID', 'postedDate', 'quantityAvailable'] FIELDS = ['id', 'name', 'image', 'description', 'cost',
'category', 'sellerID', 'postedDate', 'quantityAvailable']
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@ -25,7 +27,6 @@ class ProductController(DatabaseController):
) )
self._conn.commit() self._conn.commit()
def read(self, name: str = "") -> list[Product] | None: def read(self, name: str = "") -> list[Product] | None:
params = [ params = [
"%" + name + "%" "%" + name + "%"
@ -39,24 +40,23 @@ class ProductController(DatabaseController):
if rows == None: if rows == None:
return None return None
products = list() products = list()
# Create an object for each row # Create an object for each row
for product in rows: for product in rows:
params = dict(zip(self.FIELDS, product)) params = dict(zip(self.FIELDS, product))
obj = self.new_instance(Product, params) obj = self.new_instance(Product, params)
products.append(obj) products.append(obj)
return products
return products
def read_all(self, category: str = "", search_term: str = "") -> list[Product] | None: def read_all(self, category: str = "", search_term: str = "") -> list[Product] | None:
params = [ params = [
"%" + category + "%", "%" + category + "%",
"%" + search_term + "%" "%" + search_term + "%"
] ]
cursor = self._conn.execute( cursor = self._conn.execute(
"""SELECT * FROM Products """SELECT * FROM Products
INNER JOIN Categories ON Products.categoryID = Categories.id INNER JOIN Categories ON Products.categoryID = Categories.id
@ -69,19 +69,19 @@ class ProductController(DatabaseController):
if len(rows) == 0: if len(rows) == 0:
return None return None
products = list() products = list()
# Create an object for each row # Create an object for each row
for product in rows: for product in rows:
params = dict(zip(self.FIELDS, product)) params = dict(zip(self.FIELDS, product))
obj = self.new_instance(Product, params) obj = self.new_instance(Product, params)
products.append(obj) products.append(obj)
return products return products
def update(self): def update(self):
print("Doing work") print("Doing work")
def delete(self): def delete(self):
print("Doing work") print("Doing work")

View File

@ -3,8 +3,10 @@ from models.users.user import User
from models.users.customer import Customer from models.users.customer import Customer
from models.users.seller import Seller from models.users.seller import Seller
class UserController(DatabaseController): class UserController(DatabaseController):
FIELDS = ['id', 'username', 'password', 'firstName', 'lastName', 'email', 'phone', 'role'] FIELDS = ['id', 'username', 'password', 'firstName',
'lastName', 'email', 'phone', 'role']
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@ -44,13 +46,12 @@ class UserController(DatabaseController):
type = Customer type = Customer
if row[7] == "Seller": if row[7] == "Seller":
type = Seller type = Seller
obj = self.new_instance(type, params) obj = self.new_instance(type, params)
return obj return obj
return None return None
def read_id(self, id: int) -> User | None: def read_id(self, id: int) -> User | None:
params = [ params = [
id id
@ -69,13 +70,14 @@ class UserController(DatabaseController):
type = Customer type = Customer
if row[7] == "Seller": if row[7] == "Seller":
type = Seller type = Seller
obj = self.new_instance(type, params) obj = self.new_instance(type, params)
return obj return obj
return None return None
def update(self): def update(self):
print("Doing work") print("Doing work")
def delete(self): def delete(self):
print("Doing work") print("Doing work")

View File

@ -16,15 +16,15 @@ blueprint.register_blueprint(product.blueprint)
# Function that returns a given user class based on the ID in the session # Function that returns a given user class based on the ID in the session
@blueprint.context_processor @blueprint.context_processor
def get_user() -> dict[User|None]: def get_user() -> dict[User | None]:
# Get the user based on the user ID # Get the user based on the user ID
user_id = session.get('user_id') user_id = session.get('user_id')
user = None user = None
if user_id != None: if user_id != None:
db = UserController() db = UserController()
user = db.read_id(user_id) user = db.read_id(user_id)
return dict(user=user) return dict(user=user)

View File

@ -15,14 +15,18 @@ import pathlib
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif'} ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif'}
def allowed_file(filename): def allowed_file(filename):
return '.' in filename and \ return '.' in filename and \
filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
blueprint = Blueprint("products", __name__, url_prefix="/products") blueprint = Blueprint("products", __name__, url_prefix="/products")
# Global context to enable the categories to be accessed # Global context to enable the categories to be accessed
# from any view # from any view
@blueprint.context_processor @blueprint.context_processor
def category_list(): def category_list():
database = CategoryController() database = CategoryController()
@ -30,6 +34,8 @@ def category_list():
return dict(categories=categories) return dict(categories=categories)
# Loads the front product page # Loads the front product page
@blueprint.route('/') @blueprint.route('/')
def index(): def index():
database = ProductController() database = ProductController()
@ -38,14 +44,16 @@ def index():
# No Products visible # No Products visible
if products == None: if products == None:
flash("No Products available") flash("No Products available")
return render_template('index.html', content="content.html", products = products) return render_template('index.html', content="content.html", products=products)
# Loads a given product category page # Loads a given product category page
@blueprint.route('/<string:category>') @blueprint.route('/<string:category>')
def category(category: str): def category(category: str):
database = ProductController() database = ProductController()
# Check to see if there is a custome search term # Check to see if there is a custome search term
search_term = request.args.get("search", type=str) search_term = request.args.get("search", type=str)
if search_term != None: if search_term != None:
@ -57,10 +65,12 @@ def category(category: str):
# No Products visible # No Products visible
if products == None: if products == None:
flash(f"No Products available in {category}") flash(f"No Products available in {category}")
return render_template('index.html', content="content.html", products = products, category = category) return render_template('index.html', content="content.html", products=products, category=category)
# Loads a given product based on ID # Loads a given product based on ID
@blueprint.route('/<int:id>') @blueprint.route('/<int:id>')
def id(id: int): def id(id: int):
return "ID: " + str(id) return "ID: " + str(id)
@ -70,12 +80,12 @@ def id(id: int):
@blueprint.route('/add') @blueprint.route('/add')
def display_add_product(): def display_add_product():
user_id = session.get('user_id') user_id = session.get('user_id')
# User must be logged in to view this page # User must be logged in to view this page
if user_id == None: if user_id == None:
flash("Please Login to view this page") flash("Please Login to view this page")
return redirect('/login') return redirect('/login')
db = UserController() db = UserController()
user = db.read_id(user_id) user = db.read_id(user_id)
if user == None or user.role != "Seller": if user == None or user.role != "Seller":
@ -89,12 +99,12 @@ def display_add_product():
@blueprint.post('/add') @blueprint.post('/add')
def add_product(): def add_product():
user_id = session.get('user_id') user_id = session.get('user_id')
# User must be logged in to view this page # User must be logged in to view this page
if user_id == None: if user_id == None:
flash("Please Login to view this page") flash("Please Login to view this page")
return redirect('/login', code=302) return redirect('/login', code=302)
db = UserController() db = UserController()
user = db.read_id(user_id) user = db.read_id(user_id)
if user == None or user.role != "Seller": if user == None or user.role != "Seller":
@ -102,7 +112,7 @@ def add_product():
return redirect('/', code=302) return redirect('/', code=302)
file = request.files.get('image') file = request.files.get('image')
# Ensure that the correct file type is uploaded # Ensure that the correct file type is uploaded
if file == None or not allowed_file(file.filename): if file == None or not allowed_file(file.filename):
flash("Invalid File Uploaded") flash("Invalid File Uploaded")

View File

@ -10,13 +10,17 @@ from hashlib import sha512
# Blueprint to append user endpoints to # Blueprint to append user endpoints to
blueprint = Blueprint("users", __name__) blueprint = Blueprint("users", __name__)
### LOGIN FUNCTIONALITY # LOGIN FUNCTIONALITY
# Function responsible for delivering the Login page for the site # Function responsible for delivering the Login page for the site
@blueprint.route('/login') @blueprint.route('/login')
def display_login(): def display_login():
return render_template('index.html', content="login.html") return render_template('index.html', content="login.html")
# Function responsible for handling logins to the site # Function responsible for handling logins to the site
@blueprint.post('/login') @blueprint.post('/login')
def login(): def login():
database = UserController() database = UserController()
@ -28,7 +32,7 @@ def login():
error = "No user found with the username " + request.form['username'] error = "No user found with the username " + request.form['username']
flash(error) flash(error)
return redirect("/login") return redirect("/login")
# Incorrect Password # Incorrect Password
if sha512(request.form['password'].encode()).hexdigest() != user.password: if sha512(request.form['password'].encode()).hexdigest() != user.password:
error = "Incorrect Password" error = "Incorrect Password"
@ -39,13 +43,15 @@ def login():
return redirect("/") return redirect("/")
### SIGNUP FUNCTIONALITY # SIGNUP FUNCTIONALITY
# Function responsible for delivering the Signup page for the site # Function responsible for delivering the Signup page for the site
@blueprint.route('/signup') @blueprint.route('/signup')
def display_signup(): def display_signup():
return render_template('index.html', content="signup.html") return render_template('index.html', content="signup.html")
# Function responsible for handling signups to the site # Function responsible for handling signups to the site
@blueprint.post('/signup') @blueprint.post('/signup')
def signup(): def signup():
database = UserController() database = UserController()
@ -60,7 +66,8 @@ def signup():
if request.form.get('seller'): if request.form.get('seller'):
user = Seller( user = Seller(
request.form['username'], request.form['username'],
sha512(request.form['password'].encode()).hexdigest(), # Hashed as soon as it is recieved on the backend # Hashed as soon as it is recieved on the backend
sha512(request.form['password'].encode()).hexdigest(),
request.form['firstname'], request.form['firstname'],
request.form['lastname'], request.form['lastname'],
request.form['email'], request.form['email'],
@ -69,23 +76,24 @@ def signup():
else: else:
user = Customer( user = Customer(
request.form['username'], request.form['username'],
sha512(request.form['password'].encode()).hexdigest(), # Hashed as soon as it is recieved on the backend # Hashed as soon as it is recieved on the backend
sha512(request.form['password'].encode()).hexdigest(),
request.form['firstname'], request.form['firstname'],
request.form['lastname'], request.form['lastname'],
request.form['email'], request.form['email'],
"123" "123"
) )
database.create(user) database.create(user)
# Code 307 Preserves the original request (POST) # Code 307 Preserves the original request (POST)
return redirect("/login", code=307) return redirect("/login", code=307)
### SIGN OUT FUNCTIONALITY # SIGN OUT FUNCTIONALITY
# Function responsible for handling logouts from the site # Function responsible for handling logouts from the site
@blueprint.route('/logout') @blueprint.route('/logout')
def logout(): def logout():
# Clear the current user from the session if they are logged in # Clear the current user from the session if they are logged in
session.pop('user_id', None) session.pop('user_id', None)
return redirect("/") return redirect("/")

View File

@ -2,6 +2,7 @@ class Category:
''' '''
Constructor for a category object Constructor for a category object
''' '''
def __init__(self): def __init__(self):
self.id = 0 self.id = 0
self.name = "" self.name = ""

View File

@ -1,13 +1,15 @@
from datetime import datetime from datetime import datetime
class Order: class Order:
''' '''
Constructor for an order object Constructor for an order object
''' '''
def __init__(self): def __init__(self):
self.id = 0 self.id = 0
self.sellerID = 0 self.sellerID = 0
self.customerID = 0 self.customerID = 0
self.products = list() self.products = list()
self.totalCost = 0.0 self.totalCost = 0.0
self.orderDate = datetime.now() self.orderDate = datetime.now()

View File

@ -1,13 +1,14 @@
from product import Product from product import Product
class CarPart(Product): class CarPart(Product):
''' '''
Constructor for a car part Constructor for a car part
Contains additional information that is only relevant for car parts Contains additional information that is only relevant for car parts
''' '''
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.make = "" self.make = ""
self.compatibleVehicles = list() self.compatibleVehicles = list()

View File

@ -1,9 +1,11 @@
from datetime import datetime from datetime import datetime
class Product: class Product:
''' '''
Base class for a product Base class for a product
''' '''
def __init__(self): def __init__(self):
self.id = 0 self.id = 0
self.name = "" self.name = ""
@ -20,7 +22,8 @@ class Product:
No additional properties are assigned to the customer No additional properties are assigned to the customer
''' '''
def __init__(self, name: str, image: str, description: str, cost: float, category: int,
def __init__(self, name: str, image: str, description: str, cost: float, category: int,
seller_id: int, posted_date: datetime, quantity_available: int): seller_id: int, posted_date: datetime, quantity_available: int):
self.id = 0 self.id = 0
self.name = name self.name = name
@ -31,4 +34,3 @@ class Product:
self.sellerID = seller_id self.sellerID = seller_id
self.postedDate = posted_date self.postedDate = posted_date
self.quantityAvailable = quantity_available self.quantityAvailable = quantity_available

View File

@ -1,12 +1,14 @@
from .user import User from .user import User
class Admin(User): class Admin(User):
''' '''
Class constructor to instatiate an admin object Class constructor to instatiate an admin object
No additional properties are assigned to the admin No additional properties are assigned to the admin
''' '''
def __init__(self, username: str, password: str, firstname: str,
def __init__(self, username: str, password: str, firstname: str,
lastname: str, email: str, phone: str): lastname: str, email: str, phone: str):
super().__init__( super().__init__(
username, password, firstname, lastname, email, phone, "Admin" username, password, firstname, lastname, email, phone, "Admin"

View File

@ -1,14 +1,15 @@
from .user import User from .user import User
class Customer(User): class Customer(User):
''' '''
Class constructor to instatiate a customer object Class constructor to instatiate a customer object
No additional properties are assigned to the customer No additional properties are assigned to the customer
''' '''
def __init__(self, username: str, password: str, firstname: str,
def __init__(self, username: str, password: str, firstname: str,
lastname: str, email: str, phone: str): lastname: str, email: str, phone: str):
super().__init__( super().__init__(
username, password, firstname, lastname, email, phone, "Customer" username, password, firstname, lastname, email, phone, "Customer"
) )

View File

@ -1,12 +1,14 @@
from .user import User from .user import User
class Seller(User): class Seller(User):
''' '''
Class constructor to instatiate a customer object Class constructor to instatiate a customer object
No additional properties are assigned to the customer No additional properties are assigned to the customer
''' '''
def __init__(self, username: str, password: str, firstname: str,
def __init__(self, username: str, password: str, firstname: str,
lastname: str, email: str, phone: str): lastname: str, email: str, phone: str):
super().__init__( super().__init__(
username, password, firstname, lastname, email, phone, "Seller" username, password, firstname, lastname, email, phone, "Seller"

View File

@ -1,9 +1,11 @@
from abc import ABC from abc import ABC
class User(ABC): class User(ABC):
""" Functional Class constructor to initialise all properties in the base object """ Functional Class constructor to initialise all properties in the base object
with a value """ with a value """
def __init__(self, username: str, password: str, firstname: str,
def __init__(self, username: str, password: str, firstname: str,
lastname: str, email: str, phone: str, role: str): lastname: str, email: str, phone: str, role: str):
self.id = 0 self.id = 0
self.username = username self.username = username
@ -12,4 +14,4 @@ class User(ABC):
self.lastName = lastname self.lastName = lastname
self.email = email self.email = email
self.phone = phone self.phone = phone
self.role= role self.role = role

View File

@ -14,14 +14,14 @@ def create_connection(path: str, filename: str):
print("Database file open") print("Database file open")
# Execute creation scripts # Execute creation scripts
sql = open("scripts/create_tables.sql", "r"); sql = open("scripts/create_tables.sql", "r")
conn.executescript(sql.read()) conn.executescript(sql.read())
print("Table creation complete") print("Table creation complete")
# Populate with test data if we are in Test Mode # Populate with test data if we are in Test Mode
if os.environ.get("ENVIRON") == "test": if os.environ.get("ENVIRON") == "test":
sql = open("scripts/test_data.sql", "r"); sql = open("scripts/test_data.sql", "r")
conn.executescript(sql.read()) conn.executescript(sql.read())
except sqlite3.Error as e: except sqlite3.Error as e:
@ -31,12 +31,15 @@ def create_connection(path: str, filename: str):
conn.close() conn.close()
# Ensure a directory is created given a path to it # Ensure a directory is created given a path to it
def create_directory(dir: str): def create_directory(dir: str):
try: try:
os.makedirs(dir) os.makedirs(dir)
except FileExistsError: except FileExistsError:
pass pass
def remove_file(dir: str): def remove_file(dir: str):
try: try:
os.remove(dir) os.remove(dir)
@ -44,7 +47,6 @@ def remove_file(dir: str):
pass pass
dir = r"./data/" dir = r"./data/"
db_name = r"wmgzon.db" db_name = r"wmgzon.db"

View File

@ -1 +0,0 @@

View File

@ -1,8 +1,8 @@
# Ensure test environment is set before using # Ensure test environment is set before using
import scripts.create_database
import os import os
# Setup test environment variables # Setup test environment variables
os.environ["ENVIRON"] = "test" os.environ["ENVIRON"] = "test"
# Runs the database creation scripts # Runs the database creation scripts
import scripts.create_database

View File

@ -5,22 +5,26 @@ from controllers.database.product import ProductController
from models.products.product import Product from models.products.product import Product
product = Product( product = Product(
"product", "product",
"image.png", "image.png",
"description", "description",
10.00, 10.00,
1, 1,
1, 1,
datetime.now(), datetime.now(),
1 1
) )
# Tests a new product can be created # Tests a new product can be created
def test_create_product(): def test_create_product():
db = ProductController() db = ProductController()
db.create(product) db.create(product)
# Tests the database maintains integrity when we try and add a product with the same details # Tests the database maintains integrity when we try and add a product with the same details
@pytest.mark.skip @pytest.mark.skip
def test_duplicate_product(): def test_duplicate_product():
db = ProductController() db = ProductController()
@ -28,11 +32,13 @@ def test_duplicate_product():
db.create(product) db.create(product)
# Tests that products can be refined by category # Tests that products can be refined by category
def test_search_category(): def test_search_category():
db = ProductController() db = ProductController()
# Check each category for correct amount of test products # Check each category for correct amount of test products
assert len(db.read_all("Car Parts")) == 9 + 1 # Added in previous test assert len(db.read_all("Car Parts")) == 9 + 1 # Added in previous test
assert len(db.read_all("Books")) == 9 assert len(db.read_all("Books")) == 9
assert db.read_all("Phones") == None assert db.read_all("Phones") == None
@ -48,6 +54,8 @@ def test_search_term():
assert db.read_all(search_term="not_test") == None assert db.read_all(search_term="not_test") == None
# Test we the same product details get returned from the database # Test we the same product details get returned from the database
def test_read_product(): def test_read_product():
db = ProductController() db = ProductController()

View File

@ -5,35 +5,41 @@ from models.users.customer import Customer
from models.users.seller import Seller from models.users.seller import Seller
customer = Customer( customer = Customer(
"testcustomer", "testcustomer",
"Password1", "Password1",
"firstname", "firstname",
"lastname", "lastname",
"test@test", "test@test",
"123456789" "123456789"
) )
seller = Seller( seller = Seller(
"testseller", "testseller",
"Password1", "Password1",
"firstname", "firstname",
"lastname", "lastname",
"seller@seller", "seller@seller",
"987654321" "987654321"
) )
# Tests a new user can be created # Tests a new user can be created
def test_create_user(): def test_create_user():
db = UserController() db = UserController()
db.create(customer) db.create(customer)
# Tests the database maintains integrity when we try and add a user with the same details # Tests the database maintains integrity when we try and add a user with the same details
def test_duplicate_user(): def test_duplicate_user():
db = UserController() db = UserController()
with pytest.raises(sqlite3.IntegrityError): with pytest.raises(sqlite3.IntegrityError):
db.create(customer) db.create(customer)
# Test we the same user details get returned from the database # Test we the same user details get returned from the database
def test_read_user(): def test_read_user():
db = UserController() db = UserController()
@ -52,6 +58,8 @@ def test_create_seller():
db.create(seller) db.create(seller)
# Test that the same seller details get returned from the database # Test that the same seller details get returned from the database
def test_read_seller(): def test_read_seller():
db = UserController() db = UserController()

View File

@ -1,9 +1,11 @@
import pycodestyle import pycodestyle
# Tests files to ensure they conform to pep8 standards # Tests files to ensure they conform to pep8 standards
def test_pep8_conformance(): def test_pep8_conformance():
"""Test that we conform to PEP8.""" """Test that we conform to PEP8."""
pep8style = pycodestyle.StyleGuide() pep8style = pycodestyle.StyleGuide()
dirs = ["./controllers", "./models", "./scripts", "./tests"] dirs = ["./controllers", "./models", "./scripts", "./tests"]
result = pep8style.check_files(dirs) result = pep8style.check_files(dirs)
assert result.total_errors == 0 assert result.total_errors == 0