Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,13 @@ jobs:
- name: Check out source code
uses: actions/checkout@v4

- name: Install ruff
run: pip install ruff
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.10'

- name: Ruff lint
run: ruff check .
- name: Install pre-commit
run: pip install pre-commit

- name: Ruff format
run: ruff format --check --diff .
- name: Run pre-commit hooks
run: pre-commit run --all-files --show-diff-on-failure
82 changes: 80 additions & 2 deletions api/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,93 @@

"""User authentication utilities"""

from typing import Optional

import jwt as pyjwt
from fastapi_users import exceptions, models
from fastapi_users.authentication import (
AuthenticationBackend,
BearerTransport,
JWTStrategy,
)
from fastapi_users.jwt import SecretType, decode_jwt
from fastapi_users.manager import BaseUserManager
from passlib.context import CryptContext

from .config import AuthSettings


class DualSecretJWTStrategy(JWTStrategy):
"""JWTStrategy that accepts tokens signed with either of two secrets.

Tokens are always *written* with the primary secret. On *read*, the
primary secret is tried first; if verification fails **and** a unified
secret is configured, the token is retried with the unified secret.
"""

def __init__(
self,
secret: SecretType,
lifetime_seconds: Optional[int],
algorithm: str = "HS256",
unified_secret: str = "",
):
super().__init__(
secret=secret,
lifetime_seconds=lifetime_seconds,
algorithm=algorithm,
)
self.unified_secret = unified_secret

async def read_token(
self,
token: Optional[str],
user_manager: BaseUserManager[models.UP, models.ID],
) -> Optional[models.UP]:
if token is None:
return None

# Try primary secret first
user = await self._decode_and_lookup(
token, self.decode_key, user_manager
)
if user is not None:
return user

# Fallback to unified secret
if self.unified_secret:
return await self._decode_and_lookup(
token, self.unified_secret, user_manager
)

return None

async def _decode_and_lookup(
self,
token: str,
secret: SecretType,
user_manager: BaseUserManager[models.UP, models.ID],
) -> Optional[models.UP]:
try:
data = decode_jwt(
token,
secret,
self.token_audience,
algorithms=[self.algorithm],
)
user_id = data.get("sub")
if user_id is None:
return None
except pyjwt.PyJWTError:
return None

try:
parsed_id = user_manager.parse_id(user_id)
return await user_manager.get(parsed_id)
except (exceptions.UserNotExists, exceptions.InvalidID):
return None


class Authentication:
"""Authentication utility class"""

Expand All @@ -30,12 +107,13 @@ def get_password_hash(cls, password):
"""Get a password hash for a given clear text password string"""
return cls.CRYPT_CTX.hash(password)

def get_jwt_strategy(self) -> JWTStrategy:
def get_jwt_strategy(self) -> DualSecretJWTStrategy:
"""Get JWT strategy for authentication backend"""
return JWTStrategy(
return DualSecretJWTStrategy(
secret=self._settings.secret_key,
algorithm=self._settings.algorithm,
lifetime_seconds=self._settings.access_token_expire_seconds,
unified_secret=self._settings.unified_secret,
)

def get_user_authentication_backend(self):
Expand Down
4 changes: 1 addition & 3 deletions api/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,18 @@
from pydantic_settings import BaseSettings


# pylint: disable=too-few-public-methods
class AuthSettings(BaseSettings):
"""Authentication settings"""

secret_key: str
unified_secret: str = ""
algorithm: str = "HS256"
# Set to None so tokens don't expire
access_token_expire_seconds: float = 315360000
invite_token_expire_seconds: int = 60 * 60 * 24 * 7 # 7 days
public_base_url: str | None = None


# pylint: disable=too-few-public-methods
class PubSubSettings(BaseSettings):
"""Pub/Sub settings loaded from the environment"""

Expand All @@ -35,7 +34,6 @@ class PubSubSettings(BaseSettings):
subscriber_state_ttl_days: int = 30 # Cleanup unused subscriber states


# pylint: disable=too-few-public-methods
class EmailSettings(BaseSettings):
"""Email settings"""

Expand Down
2 changes: 1 addition & 1 deletion api/email_sender.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from .config import EmailSettings


class EmailSender: # pylint: disable=too-few-public-methods
class EmailSender:
"""Class to send email report using SMTP"""

def __init__(self):
Expand Down
12 changes: 5 additions & 7 deletions api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
# Author: Jeny Sadadia <jeny.sadadia@collabora.com>
# Author: Denys Fedoryshchenko <denys.f@collabora.com>

# pylint: disable=unused-argument,global-statement,too-many-lines

"""KernelCI API main module"""

import asyncio
Expand Down Expand Up @@ -120,7 +118,7 @@ def _validate_startup_environment():


@asynccontextmanager
async def lifespan(app: FastAPI): # pylint: disable=redefined-outer-name
async def lifespan(app: FastAPI):
"""Lifespan functions for startup and shutdown events"""
await pubsub_startup()
await create_indexes()
Expand All @@ -139,7 +137,7 @@ async def lifespan(app: FastAPI): # pylint: disable=redefined-outer-name
app = FastAPI(lifespan=lifespan, debug=True, docs_url=None, redoc_url=None)
db = Database(service=os.getenv("MONGO_SERVICE", DEFAULT_MONGO_SERVICE))
auth = Authentication(token_url="user/login")
pubsub = None # pylint: disable=invalid-name
pubsub = None

auth_backend = auth.get_user_authentication_backend()
fastapi_users_instance = FastAPIUsers[User, PydanticObjectId](
Expand All @@ -151,7 +149,7 @@ async def lifespan(app: FastAPI): # pylint: disable=redefined-outer-name

async def pubsub_startup():
"""Startup event handler to create Pub/Sub object"""
global pubsub # pylint: disable=invalid-name
global pubsub
pubsub = await PubSub.create()


Expand Down Expand Up @@ -557,7 +555,7 @@ async def invite_user(
invite_url,
)
email_sent = True
except Exception as exc: # pylint: disable=broad-exception-caught
except Exception as exc:
print(f"Failed to send invite email: {exc}")

return UserInviteResponse(
Expand Down Expand Up @@ -640,7 +638,7 @@ async def accept_invite(accept: InviteAcceptRequest):

try:
await user_manager.send_invite_accepted_email(updated_user)
except Exception as exc: # pylint: disable=broad-exception-caught
except Exception as exc:
print(f"Failed to send invite accepted email: {exc}")
return updated_user

Expand Down
22 changes: 8 additions & 14 deletions api/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,6 @@
# Copyright (C) 2023 Collabora Limited
# Author: Jeny Sadadia <jeny.sadadia@collabora.com>

# Disable flag as user models don't require any public methods
# at the moment
# pylint: disable=too-few-public-methods

# pylint: disable=no-name-in-module

"""Server-side model definitions"""

from datetime import datetime
Expand Down Expand Up @@ -120,7 +114,7 @@ class UserGroupCreateRequest(BaseModel):

class User(
BeanieBaseUser,
Document, # pylint: disable=too-many-ancestors
Document,
DatabaseModel,
):
"""API User model"""
Expand All @@ -131,7 +125,7 @@ class User(
)

@field_validator("groups")
def validate_groups(cls, groups): # pylint: disable=no-self-argument
def validate_groups(cls, groups):
"""Unique group constraint"""
unique_names = {group.name for group in groups}
if len(unique_names) != len(groups):
Expand Down Expand Up @@ -159,7 +153,7 @@ class UserRead(schemas.BaseUser[PydanticObjectId], ModelId):
groups: List[UserGroup] = Field(default=[])

@field_validator("groups")
def validate_groups(cls, groups): # pylint: disable=no-self-argument
def validate_groups(cls, groups):
"""Unique group constraint"""
unique_names = {group.name for group in groups}
if len(unique_names) != len(groups):
Expand All @@ -174,7 +168,7 @@ class UserCreateRequest(schemas.BaseUserCreate):
groups: List[str] = Field(default=[])

@field_validator("groups")
def validate_groups(cls, groups): # pylint: disable=no-self-argument
def validate_groups(cls, groups):
"""Unique group constraint"""
unique_names = set(groups)
if len(unique_names) != len(groups):
Expand All @@ -189,7 +183,7 @@ class UserCreate(schemas.BaseUserCreate):
groups: List[UserGroup] = Field(default=[])

@field_validator("groups")
def validate_groups(cls, groups): # pylint: disable=no-self-argument
def validate_groups(cls, groups):
"""Unique group constraint"""
unique_names = {group.name for group in groups}
if len(unique_names) != len(groups):
Expand All @@ -206,7 +200,7 @@ class UserUpdateRequest(schemas.BaseUserUpdate):
groups: List[str] = Field(default=[])

@field_validator("groups")
def validate_groups(cls, groups): # pylint: disable=no-self-argument
def validate_groups(cls, groups):
"""Unique group constraint"""
unique_names = set(groups)
if len(unique_names) != len(groups):
Expand All @@ -223,7 +217,7 @@ class UserUpdate(schemas.BaseUserUpdate):
groups: List[UserGroup] = Field(default=[])

@field_validator("groups")
def validate_groups(cls, groups): # pylint: disable=no-self-argument
def validate_groups(cls, groups):
"""Unique group constraint"""
unique_names = {group.name for group in groups}
if len(unique_names) != len(groups):
Expand All @@ -246,7 +240,7 @@ class UserInviteRequest(BaseModel):
resend_if_exists: bool = False

@field_validator("groups")
def validate_groups(cls, groups): # pylint: disable=no-self-argument
def validate_groups(cls, groups):
"""Unique group constraint"""
unique_names = set(groups)
if len(unique_names) != len(groups):
Expand Down
5 changes: 1 addition & 4 deletions api/pubsub_mongo.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
# Copyright (C) 2025 Collabora Limited
# Author: Denys Fedoryshchenko <denys.f@collabora.com>

# pylint: disable=duplicate-code
# Note: This module intentionally shares interface code with pubsub.py
# as both implement the same PubSub API contract

Expand Down Expand Up @@ -35,7 +34,7 @@
logger = logging.getLogger(__name__)


class PubSub: # pylint: disable=too-many-instance-attributes
class PubSub:
"""Hybrid Pub/Sub implementation with MongoDB durability

Supports two modes:
Expand Down Expand Up @@ -328,7 +327,6 @@ def _eventhistory_to_cloudevent(self, event: Dict) -> str:
ce = CloudEvent(attributes=attributes, data=event.get("data", {}))
return to_json(ce).decode("utf-8")

# pylint: disable=too-many-arguments
async def _get_missed_events(
self,
channel: str,
Expand Down Expand Up @@ -410,7 +408,6 @@ async def subscribe(

return sub

# pylint: disable=too-many-arguments
async def _setup_durable_subscription(
self,
sub_id: int,
Expand Down
2 changes: 1 addition & 1 deletion tests/e2e_tests/listen_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def create_listen_task(test_async_client, subscription_id):
listen_path,
headers={
"Accept": "application/json",
"Authorization": f"Bearer {pytest.BEARER_TOKEN}", # pylint: disable=no-member
"Authorization": f"Bearer {pytest.BEARER_TOKEN}",
},
)
)
Expand Down
10 changes: 5 additions & 5 deletions tests/e2e_tests/test_node_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ async def create_node(test_async_client, node):
"node",
headers={
"Accept": "application/json",
"Authorization": f"Bearer {pytest.BEARER_TOKEN}", # pylint: disable=no-member
"Authorization": f"Bearer {pytest.BEARER_TOKEN}",
},
data=json.dumps(node),
)
Expand All @@ -44,7 +44,7 @@ async def get_node_by_id(test_async_client, node_id):
f"node/{node_id}",
headers={
"Accept": "application/json",
"Authorization": f"Bearer {pytest.BEARER_TOKEN}", # pylint: disable=no-member
"Authorization": f"Bearer {pytest.BEARER_TOKEN}",
},
)
assert response.status_code == 200
Expand All @@ -65,7 +65,7 @@ async def get_node_by_attribute(test_async_client, params):
params=params,
headers={
"Accept": "application/json",
"Authorization": f"Bearer {pytest.BEARER_TOKEN}", # pylint: disable=no-member
"Authorization": f"Bearer {pytest.BEARER_TOKEN}",
},
)
assert response.status_code == 200
Expand All @@ -85,7 +85,7 @@ async def update_node(test_async_client, node):
f"node/{node['id']}",
headers={
"Accept": "application/json",
"Authorization": f"Bearer {pytest.BEARER_TOKEN}", # pylint: disable=no-member
"Authorization": f"Bearer {pytest.BEARER_TOKEN}",
},
data=json.dumps(node),
)
Expand All @@ -104,7 +104,7 @@ async def patch_node(test_async_client, node_id, patch_data):
f"node/{node_id}",
headers={
"Accept": "application/json",
"Authorization": f"Bearer {pytest.BEARER_TOKEN}", # pylint: disable=no-member
"Authorization": f"Bearer {pytest.BEARER_TOKEN}",
},
data=json.dumps(patch_data),
)
Expand Down
Loading
Loading