Commit 2eb4b69e authored by Anthony Jacob's avatar Anthony Jacob
Browse files

add basic JWT auth

parent 1aa573aa
Loading
Loading
Loading
Loading
+3 −2
Original line number Diff line number Diff line
import os
from dotenv import load_dotenv
from datetime import timedelta

# Load environment variables
load_dotenv()
@@ -11,8 +12,8 @@ class Config:
    REDIS_HOST = os.getenv('REDIS_HOST')
    REDIS_PORT = os.getenv('REDIS_PORT')
    JWT_SECRET_KEY = os.getenv('JWT_SECRET_KEY')
    JWT_ACCESS_TOKEN_EXPIRES = os.getenv('JWT_ACCESS_TOKEN_EXPIRES')
    JWT_REFRESH_TOKEN_EXPIRES = os.getenv('JWT_REFRESH_TOKEN_EXPIRES')
    JWT_ACCESS_TOKEN_EXPIRES =  timedelta(seconds=int(os.getenv('JWT_ACCESS_TOKEN_EXPIRES')))
    JWT_REFRESH_TOKEN_EXPIRES =  timedelta(seconds=int(os.getenv('JWT_REFRESH_TOKEN_EXPIRES')))
    POSTGRES_HOST = os.getenv('POSTGRES_HOST')
    POSTGRES_PORT= os.getenv('POSTGRES_PORT')
    POSTGRES_DB= os.getenv('POSTGRES_DB')

app/controller/auth.py

0 → 100644
+56 −0
Original line number Diff line number Diff line
from flask import Blueprint, request, jsonify, current_app
from helpers.security import check_auth, revoke_jwt
from helpers.limiter import limiter
from flask_jwt_extended import (
    JWTManager, create_access_token, create_refresh_token,
    jwt_required, get_jwt_identity, get_jwt
)

auth_bp = Blueprint('auth', __name__)

@auth_bp.route("/login", methods=["POST"])
def login():
    """Authenticate user and return access & refresh tokens."""
    data = request.json
    username = data.get("username")
    password = data.get("password")

    print(data)

    if not username or not password or not check_auth(username=username, password=password):
        return jsonify({"error": "Invalid credentials"}), 401

    # Create tokens
    # access_token = create_access_token(identity={"username": username, "role": user["role"]})
    # refresh_token = create_refresh_token(identity={"username": username, "role": user["role"]})
    access_token = create_access_token(identity={"username": username})
    refresh_token = create_refresh_token(identity={"username": username})

    return jsonify({"access_token": access_token, "refresh_token": refresh_token})


@auth_bp.route("/logout", methods=["POST"])
@jwt_required()
def logout():
    """Blacklist the token to prevent reuse."""
    jti = get_jwt()["jti"]  # JWT unique identifier
    revoke_jwt(jti)
    return jsonify({"message": f"Token revoked {jti}"}), 200


@auth_bp.route("/refresh", methods=["POST"])
@jwt_required(refresh=True)
def refresh():
    """Generate a new access token using the refresh token."""
    identity = get_jwt_identity()
    new_access_token = create_access_token(identity=identity)
    return jsonify({"access_token": new_access_token})

@auth_bp.route("/CheckConnected", methods=["GET"])
@jwt_required(optional=True)
def optionally_protected():
    current_identity = get_jwt_identity()
    if current_identity:
        return jsonify(logged_in_as=current_identity)
    else:
        return jsonify(logged_in_as="anonymous user")
 No newline at end of file
+47 −0
Original line number Diff line number Diff line
@@ -95,3 +95,50 @@ def check_auth(username, password):
    except Exception as e:
        current_app.logger.error(f"Database error while checking API key: {e}")
        return False  # Return False in case of any DB error


def revoke_jwt(jti):
    identity = get_jwt_identity()

    getUserQuery = """SELECT id from "user" WHERE UPPER(login) = UPPER(%s)"""


    InsertRevokeQuery = """INSERT INTO revoked_jwt (jti, user_id)
               VALUES (%s, %s)"""

    try:
        with current_app.db_pool.connection() as conn:
            with conn.cursor() as cur:
                cur.execute(getUserQuery, (identity['username'],))
                row = cur.fetchone()
                if row:
                    print(row)
                    user_id = row[0]
                    cur.execute(InsertRevokeQuery, (jti, user_id))  # Insert both jti and user_id
                    conn.commit()  # Commit the transaction after the insert
                    return True  # Successfully revoked JWT

        return False
    except Exception as e:
        current_app.logger.error(f"Database error while revoking JWT: {e}")
        return False  # Return False in case of any DB error

def is_jwt_revoked(jti):
    # Check if a JWT has been revoked by looking for the jti in the revoked_jwt table.
    query = """SELECT 1
               FROM revoked_jwt
               WHERE jti = %s
               LIMIT 1"""

    try:
        with current_app.db_pool.connection() as conn:
            with conn.cursor() as cur:
                cur.execute(query, (jti,))
                row = cur.fetchone()
                if row:
                    return True
                return False

    except Exception as e:
        current_app.logger.error(f"Database error while checking JWT revocation: {e}")
        return False
+11 −1
Original line number Diff line number Diff line
from flask import Flask, jsonify
from psycopg_pool import ConnectionPool
from redis import Redis
from flask_jwt_extended import JWTManager
from app.config import Config
from app.controller import diploma,healthcheck
from app.controller import diploma,healthcheck,auth
from flask_limiter import Limiter
from flask_limiter.util import get_remote_address
from helpers.limiter import limiter
from helpers.security import is_jwt_revoked


app = Flask(__name__)
@@ -28,6 +30,12 @@ app.redis = Redis(host=app.config['REDIS_HOST'], port=app.config['REDIS_PORT'])

limiter.init_app(app)

app.jwt = JWTManager(app)

@app.jwt.token_in_blocklist_loader
def check_if_token_revoked(jwt_header, jwt_payload):
    return is_jwt_revoked(jwt_payload["jti"])


@app.errorhandler(429)
def ratelimit_error(error):
@@ -42,6 +50,8 @@ app.register_blueprint(diploma.diploma_bp)

app.register_blueprint(healthcheck.healthcheck_bp)

app.register_blueprint(auth.auth_bp, url_prefix='/auth')


if __name__ == "__main__":
    app.run(debug=True)