Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 18 additions & 8 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ jobs:
name: Run tests with pytest
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: ["ubuntu-latest"]
os: ["ubuntu-latest", "windows-latest"]
python-version: ["3.10", "3.11", "3.12", "3.13"]
fail-fast: false

steps:
- uses: actions/checkout@v4
Expand All @@ -25,24 +25,34 @@ jobs:
python-version: ${{ matrix.python-version }}
allow-prereleases: true

- name: Create and activate a virtual environment (Unix)
- name: Install uv
uses: astral-sh/setup-uv@v6

- name: Create virtual environment
run: uv venv .venv

- name: Add venv to PATH (Linux/macOS)
if: runner.os != 'Windows'
run: |
curl -LsSf https://astral.sh/uv/install.sh | sh
uv venv .venv
echo "VIRTUAL_ENV=.venv" >> $GITHUB_ENV
echo "$PWD/.venv/bin" >> $GITHUB_PATH

# Install dependencies using uv pip
- name: Add venv to PATH (Windows)
if: runner.os == 'Windows'
shell: pwsh
run: |
"VIRTUAL_ENV=.venv" >> $env:GITHUB_ENV
"$pwd\.venv\Scripts" >> $env:GITHUB_PATH

- name: Install dependencies
run: make install-no-pre-commit

# Run tests with coverage
- name: Run tests under coverage
shell: bash
run: |
coverage run --source=model2vec -m pytest
coverage report

# Upload results to Codecov
- name: Upload results to Codecov
uses: codecov/codecov-action@v4
with:
Expand Down
4 changes: 3 additions & 1 deletion model2vec/persistence/persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,9 @@ def _resolve_folder(folder_or_repo_path: Path, token: str | None, force_download
if folder := maybe_get_cached_model_path(str(folder_or_repo_path)):
return folder

folder = Path(huggingface_hub.snapshot_download(str(folder_or_repo_path), repo_type="model", token=token))
folder = Path(
huggingface_hub.snapshot_download(str(folder_or_repo_path.as_posix()), repo_type="model", token=token)
)

return folder

Expand Down
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,11 @@ dev = [
"ruff",
]

distill = ["torch", "transformers", "scikit-learn", "skeletoken>=0.3.2"]
distill = ["torch", "transformers<5.4.0", "scikit-learn", "skeletoken>=0.3.2"]
onnx = ["onnx", "torch"]
# train also installs inference
train = ["torch", "lightning", "scikit-learn", "skops"]
inference = ["scikit-learn", "skops"]
tokenizer = ["transformers"]
quantization = ["scikit-learn"]

[project.urls]
Expand Down
9 changes: 5 additions & 4 deletions tests/test_persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,17 @@ def test_local_loading(mock_static_model: StaticModel) -> None:
with patch("model2vec.persistence.persistence.maybe_get_cached_model_path") as cache:
# Simulate cache hit
cache.return_value = Path(dir_name)
s = StaticModel.from_pretrained("haha", force_download=True)
s = StaticModel.from_pretrained("my_org/haha", force_download=True)
assert mock_snapshot.call_args[0] == ("my_org/haha",)
assert s.tokens == mock_static_model.tokens
s = StaticModel.from_pretrained("haha", force_download=False)
s = StaticModel.from_pretrained("my_org/haha", force_download=False)
assert s.tokens == mock_static_model.tokens

# Simulate cache miss
cache.return_value = None
s = StaticModel.from_pretrained("haha", force_download=True)
s = StaticModel.from_pretrained("my_org/haha", force_download=True)
assert s.tokens == mock_static_model.tokens
s = StaticModel.from_pretrained("haha", force_download=False)
s = StaticModel.from_pretrained("my_org/haha", force_download=False)
assert s.tokens == mock_static_model.tokens

# Called twice, only when `force_download` is False
Expand Down
20 changes: 9 additions & 11 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

from pathlib import Path
from tempfile import NamedTemporaryFile
from unittest.mock import patch

import pytest
Expand All @@ -16,20 +15,19 @@ def test__get_metadata_from_readme_not_exists() -> None:
assert get_metadata_from_readme(Path("zzz")) == {}


def test__get_metadata_from_readme_mocked_file() -> None:
def test__get_metadata_from_readme_mocked_file(tmp_path: Path) -> None:
"""Test getting metadata from a README."""
with NamedTemporaryFile() as f:
f.write(b"---\nkey: value\n---\n")
f.flush()
assert get_metadata_from_readme(Path(f.name))["key"] == "value"
path = tmp_path / "README.md"
path.write_text("---\nkey: value\n---\n", encoding="utf-8")

assert get_metadata_from_readme(path)["key"] == "value"

def test__get_metadata_from_readme_mocked_file_keys() -> None:

def test__get_metadata_from_readme_mocked_file_keys(tmp_path: Path) -> None:
"""Test getting metadata from a README."""
with NamedTemporaryFile() as f:
f.write(b"")
f.flush()
assert set(get_metadata_from_readme(Path(f.name))) == set()
path = tmp_path / "README.md"
path.write_text("b", encoding="utf-8")
assert set(get_metadata_from_readme(path)) == set()


@pytest.mark.parametrize(
Expand Down
Loading
Loading