Source code for pjkm.core.tasks.configure.setup_git_lfs
"""Configure task: set up Git LFS tracking for ML/data file types."""
from __future__ import annotations
import logging
import subprocess
from pjkm.core.engine.task_context import TaskContext
from pjkm.core.models.task import Phase, TaskResult
from pjkm.core.tasks.base import BaseTask
[docs]
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
# File extensions commonly used for ML artefacts and large data files.
[docs]
LFS_TRACKED_PATTERNS: list[str] = [
    # Glob patterns; each one becomes a single ``.gitattributes`` rule.
    "*.pt", "*.pth", "*.onnx", "*.safetensors",
    "*.h5", "*.hdf5", "*.pkl", "*.pickle",
    "*.model", "*.parquet", "*.npy", "*.npz",
]
def _git_lfs_available() -> bool:
"""Return True if the ``git lfs`` command is available on the system."""
try:
subprocess.run(
["git", "lfs", "version"],
capture_output=True,
check=True,
)
except (subprocess.CalledProcessError, FileNotFoundError):
return False
return True
[docs]
class SetupGitLfsTask(BaseTask):
    """Configures Git LFS tracking rules for ML/data file types.

    Only runs when the *hf* or *ml* package group is selected. If
    ``git-lfs`` is not installed the task emits a warning and skips
    gracefully rather than failing the pipeline.
    """

    phase = Phase.CONFIGURE
    depends_on: list[str] = []
    description = "Set up Git LFS tracking for large ML file types"

    # ------------------------------------------------------------------
    # Gate
    # ------------------------------------------------------------------

    def should_run(self, ctx: TaskContext) -> bool:
        """Return True when *hf* or *ml* is among ``ctx.config.selected_groups``."""
        selected = set(ctx.config.selected_groups)
        return bool(selected & {"hf", "ml"})

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _already_tracked(content: str) -> set[str]:
        """Return the patterns that *content* already tracks via Git LFS.

        *content* is the text of a ``.gitattributes`` file.  A pattern counts
        as tracked when a non-comment line starts with it and carries all of
        ``filter=lfs diff=lfs merge=lfs -text`` (in any order, possibly with
        extra attributes).  Matching whole lines rather than substrings means
        a commented-out rule or a path-scoped rule such as
        ``models/*.pt ...`` is not mistaken for a rule on ``*.pt``.
        """
        required = {"filter=lfs", "diff=lfs", "merge=lfs", "-text"}
        tracked: set[str] = set()
        for line in content.splitlines():
            fields = line.split()
            # Skip blank lines and comment lines.
            if not fields or fields[0].startswith("#"):
                continue
            if required.issubset(fields[1:]):
                tracked.add(fields[0])
        return tracked

    # ------------------------------------------------------------------
    # Execution
    # ------------------------------------------------------------------

    def execute(self, ctx: TaskContext) -> TaskResult:
        """Append missing LFS rules to ``.gitattributes`` and run ``git lfs install``.

        Args:
            ctx: Task context; ``ctx.project_dir`` and ``ctx.config.dry_run``
                are read.

        Returns:
            ``self.skip_result()`` when git-lfs is unavailable, otherwise
            ``self.success_result(...)`` carrying the number of rules added
            and the created/modified file lists.  In dry-run mode nothing is
            written and ``git lfs install`` is not run, but the file lists
            and rule count still describe what would change.

        Raises:
            subprocess.CalledProcessError: if ``git lfs install`` exits with
                a non-zero status.
        """
        project_dir = ctx.project_dir
        dry_run = ctx.config.dry_run

        # 1. Verify git-lfs is available -----------------------------------
        if not _git_lfs_available():
            logger.warning(
                "git-lfs is not installed; skipping LFS setup. "
                "Install it via `brew install git-lfs` or "
                "https://git-lfs.com to enable large-file tracking."
            )
            return self.skip_result()

        # 2. Append tracking rules to .gitattributes -----------------------
        gitattributes_path = project_dir / ".gitattributes"

        # Read existing content so we don't duplicate rules.  The text is only
        # inspected (new rules are appended, never rewritten), so undecodable
        # bytes are replaced instead of aborting the task.
        file_existed = gitattributes_path.exists()
        existing_content = ""
        if file_existed:
            existing_content = gitattributes_path.read_text(
                encoding="utf-8", errors="replace"
            )

        tracked = self._already_tracked(existing_content)
        new_lines: list[str] = [
            f"{pattern} filter=lfs diff=lfs merge=lfs -text"
            for pattern in LFS_TRACKED_PATTERNS
            if pattern not in tracked
        ]

        files_created: list[str] = []
        files_modified: list[str] = []

        if new_lines:
            if not dry_run:
                # Ensure a trailing newline before our block when appending.
                separator = (
                    "\n"
                    if existing_content and not existing_content.endswith("\n")
                    else ""
                )
                header = "# Git LFS tracking rules (auto-generated by pjkm)\n"
                block = separator + header + "\n".join(new_lines) + "\n"
                with gitattributes_path.open("a", encoding="utf-8") as fh:
                    fh.write(block)

            # Classify by whether the file existed, not by whether it had
            # content: an existing but empty file is modified, not created.
            if file_existed:
                files_modified.append(".gitattributes")
            else:
                files_created.append(".gitattributes")

        # 3. Run `git lfs install` in the project directory ----------------
        if not dry_run:
            subprocess.run(
                ["git", "lfs", "install"],
                cwd=project_dir,
                capture_output=True,
                check=True,
            )

        suffix = " (dry run)" if dry_run else ""
        patterns_count = len(new_lines)
        return self.success_result(
            message=(f"Git LFS configured with {patterns_count} tracking rule(s){suffix}"),
            files_created=files_created,
            files_modified=files_modified,
        )