1
0
mirror of https://github.com/AvengeMedia/DankMaterialShell.git synced 2026-05-06 20:42:07 -04:00

Compare commits

..

1 Commits

Author SHA1 Message Date
Marcus Ramberg
8fad2826b1 ci: add flake check 2025-12-08 16:09:16 +01:00
197 changed files with 10276 additions and 12419 deletions

View File

@@ -10,14 +10,21 @@ jobs:
check-flake: check-flake:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Create GitHub App token
id: app_token
uses: actions/create-github-app-token@v1
with:
app-id: ${{ secrets.APP_ID }}
private-key: ${{ secrets.APP_PRIVATE_KEY }}
- name: Checkout - name: Checkout
uses: actions/checkout@v4 uses: actions/checkout@v4
with: with:
fetch-depth: 0 fetch-depth: 0
token: ${{ steps.app_token.outputs.token }}
- name: Install Nix - name: Install Nix
uses: cachix/install-nix-action@v31 uses: cachix/install-nix-action@v31
- name: Check the flake - name: Update vendorHash in flake.nix
run: nix flake check run: nix flake check

View File

@@ -1,19 +1,16 @@
name: Release name: Release
on: on:
workflow_dispatch: push:
inputs: tags:
tag: - 'v*'
description: 'Tag to release (e.g., v1.0.1)'
required: true
type: string
permissions: permissions:
contents: write contents: write
actions: write actions: write
concurrency: concurrency:
group: release-${{ inputs.tag }} group: release-${{ github.ref_name }}
cancel-in-progress: true cancel-in-progress: true
jobs: jobs:
@@ -27,14 +24,10 @@ jobs:
run: run:
working-directory: core working-directory: core
env:
TAG: ${{ inputs.tag }}
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@v4 uses: actions/checkout@v4
with: with:
ref: ${{ inputs.tag }}
fetch-depth: 0 fetch-depth: 0
- name: Set up Go - name: Set up Go
@@ -61,7 +54,7 @@ jobs:
run: | run: |
set -eux set -eux
cd cmd/dankinstall cd cmd/dankinstall
go build -trimpath -ldflags "-s -w -X main.Version=${TAG}" \ go build -trimpath -ldflags "-s -w -X main.Version=${GITHUB_REF#refs/tags/}" \
-o ../../dankinstall-${{ matrix.arch }} -o ../../dankinstall-${{ matrix.arch }}
cd ../.. cd ../..
gzip -9 -k dankinstall-${{ matrix.arch }} gzip -9 -k dankinstall-${{ matrix.arch }}
@@ -75,7 +68,7 @@ jobs:
run: | run: |
set -eux set -eux
cd cmd/dms cd cmd/dms
go build -trimpath -ldflags "-s -w -X main.Version=${TAG}" \ go build -trimpath -ldflags "-s -w -X main.Version=${GITHUB_REF#refs/tags/}" \
-o ../../dms-${{ matrix.arch }} -o ../../dms-${{ matrix.arch }}
cd ../.. cd ../..
gzip -9 -k dms-${{ matrix.arch }} gzip -9 -k dms-${{ matrix.arch }}
@@ -98,7 +91,7 @@ jobs:
run: | run: |
set -eux set -eux
cd cmd/dms cd cmd/dms
go build -trimpath -tags distro_binary -ldflags "-s -w -X main.Version=${TAG}" \ go build -trimpath -tags distro_binary -ldflags "-s -w -X main.Version=${GITHUB_REF#refs/tags/}" \
-o ../../dms-distropkg-${{ matrix.arch }} -o ../../dms-distropkg-${{ matrix.arch }}
cd ../.. cd ../..
gzip -9 -k dms-distropkg-${{ matrix.arch }} gzip -9 -k dms-distropkg-${{ matrix.arch }}
@@ -135,61 +128,60 @@ jobs:
core/completion.zsh core/completion.zsh
if-no-files-found: error if-no-files-found: error
# update-versions: update-versions:
# runs-on: ubuntu-latest runs-on: ubuntu-latest
# needs: build-core needs: build-core
# steps: steps:
# - name: Create GitHub App token - name: Create GitHub App token
# id: app_token id: app_token
# uses: actions/create-github-app-token@v1 uses: actions/create-github-app-token@v1
# with: with:
# app-id: ${{ secrets.APP_ID }} app-id: ${{ secrets.APP_ID }}
# private-key: ${{ secrets.APP_PRIVATE_KEY }} private-key: ${{ secrets.APP_PRIVATE_KEY }}
# - name: Checkout - name: Checkout
# uses: actions/checkout@v4 uses: actions/checkout@v4
# with: with:
# token: ${{ steps.app_token.outputs.token }} token: ${{ steps.app_token.outputs.token }}
# fetch-depth: 0 fetch-depth: 0
# - name: Update VERSION - name: Update VERSION
# env: env:
# GH_TOKEN: ${{ steps.app_token.outputs.token }} GH_TOKEN: ${{ steps.app_token.outputs.token }}
# run: | run: |
# set -euo pipefail set -euo pipefail
# git config user.name "dms-ci[bot]" git config user.name "dms-ci[bot]"
# git config user.email "dms-ci[bot]@users.noreply.github.com" git config user.email "dms-ci[bot]@users.noreply.github.com"
# version="${GITHUB_REF#refs/tags/}" version="${GITHUB_REF#refs/tags/}"
# echo "Updating to version: $version" echo "Updating to version: $version"
# echo "${version}" > quickshell/VERSION echo "${version}" > quickshell/VERSION
# git add quickshell/VERSION git add quickshell/VERSION
# if ! git diff --cached --quiet; then if ! git diff --cached --quiet; then
# git commit -m "chore: bump version to $version" git commit -m "chore: bump version to $version"
# git pull --rebase origin master git pull --rebase origin master
# git push https://x-access-token:${GH_TOKEN}@github.com/${{ github.repository }}.git HEAD:master git push https://x-access-token:${GH_TOKEN}@github.com/${{ github.repository }}.git HEAD:master
# fi fi
# git tag -f "${version}" git tag -f "${version}"
# git push -f https://x-access-token:${GH_TOKEN}@github.com/${{ github.repository }}.git "${version}" git push -f https://x-access-token:${GH_TOKEN}@github.com/${{ github.repository }}.git "${version}"
release: release:
runs-on: ubuntu-24.04 runs-on: ubuntu-24.04
needs: [build-core] #, update-versions] needs: [build-core, update-versions]
env: env:
TAG: ${{ inputs.tag }} TAG: ${{ github.ref_name }}
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@v4 uses: actions/checkout@v4
with: with:
ref: ${{ inputs.tag }}
fetch-depth: 0 fetch-depth: 0
- name: Fetch updated tag after version bump - name: Fetch updated tag after version bump
run: | run: |
git fetch origin --force tag ${TAG} git fetch origin --force tag ${{ github.ref_name }}
git checkout ${TAG} git checkout ${{ github.ref_name }}
- name: Download core artifacts - name: Download core artifacts
uses: actions/download-artifact@v4 uses: actions/download-artifact@v4
@@ -399,46 +391,39 @@ jobs:
trigger-obs-update: trigger-obs-update:
runs-on: ubuntu-latest runs-on: ubuntu-latest
needs: release needs: release
env:
TAG: ${{ inputs.tag }}
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@v4 uses: actions/checkout@v4
with:
ref: ${{ inputs.tag }}
- name: Install OSC - name: Install OSC
run: | run: |
sudo apt-get update sudo apt-get update
sudo apt-get install -y osc sudo apt-get install -y osc
mkdir -p ~/.config/osc mkdir -p ~/.config/osc
cat > ~/.config/osc/oscrc << EOF cat > ~/.config/osc/oscrc << EOF
[general] [general]
apiurl = https://api.opensuse.org apiurl = https://api.opensuse.org
[https://api.opensuse.org] [https://api.opensuse.org]
user = ${{ secrets.OBS_USERNAME }} user = ${{ secrets.OBS_USERNAME }}
pass = ${{ secrets.OBS_PASSWORD }} pass = ${{ secrets.OBS_PASSWORD }}
EOF EOF
chmod 600 ~/.config/osc/oscrc chmod 600 ~/.config/osc/oscrc
- name: Update OBS packages - name: Update OBS packages
run: | run: |
VERSION="${{ github.ref_name }}"
cd distro cd distro
bash scripts/obs-upload.sh dms "Update to ${TAG}" bash scripts/obs-upload.sh dms "Update to $VERSION"
trigger-ppa-update: trigger-ppa-update:
runs-on: ubuntu-latest runs-on: ubuntu-latest
needs: release needs: release
env:
TAG: ${{ inputs.tag }}
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@v4 uses: actions/checkout@v4
with:
ref: ${{ inputs.tag }}
- name: Install build dependencies - name: Install build dependencies
run: | run: |
sudo apt-get update sudo apt-get update
@@ -450,7 +435,7 @@ jobs:
build-essential \ build-essential \
fakeroot \ fakeroot \
dpkg-dev dpkg-dev
- name: Configure GPG - name: Configure GPG
env: env:
GPG_KEY: ${{ secrets.GPG_PRIVATE_KEY }} GPG_KEY: ${{ secrets.GPG_PRIVATE_KEY }}
@@ -458,9 +443,10 @@ jobs:
echo "$GPG_KEY" | gpg --import echo "$GPG_KEY" | gpg --import
GPG_KEY_ID=$(gpg --list-secret-keys --keyid-format LONG | grep sec | awk '{print $2}' | cut -d'/' -f2) GPG_KEY_ID=$(gpg --list-secret-keys --keyid-format LONG | grep sec | awk '{print $2}' | cut -d'/' -f2)
echo "DEBSIGN_KEYID=$GPG_KEY_ID" >> $GITHUB_ENV echo "DEBSIGN_KEYID=$GPG_KEY_ID" >> $GITHUB_ENV
- name: Upload to PPA - name: Upload to PPA
run: | run: |
VERSION="${{ github.ref_name }}"
cd distro/ubuntu/ppa cd distro/ubuntu/ppa
bash create-and-upload.sh ../dms dms questing bash create-and-upload.sh ../dms dms questing
@@ -468,13 +454,11 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
needs: release needs: release
env: env:
TAG: ${{ inputs.tag }} TAG: ${{ github.ref_name }}
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v4 uses: actions/checkout@v4
with:
ref: ${{ inputs.tag }}
- name: Determine version - name: Determine version
id: version id: version
@@ -527,7 +511,7 @@ jobs:
Requires: (quickshell or quickshell-git) Requires: (quickshell or quickshell-git)
Requires: accountsservice Requires: accountsservice
Requires: dms-cli = %{version}-%{release} Requires: dms-cli
Requires: dgop Requires: dgop
Recommends: cava Recommends: cava
@@ -557,6 +541,17 @@ jobs:
Command-line interface for DankMaterialShell configuration and management. Command-line interface for DankMaterialShell configuration and management.
Provides native DBus bindings, NetworkManager integration, and system utilities. Provides native DBus bindings, NetworkManager integration, and system utilities.
%package -n dgop
Summary: Stateless CPU/GPU monitor for DankMaterialShell
License: MIT
URL: https://github.com/AvengeMedia/dgop
Provides: dgop
%description -n dgop
DGOP is a stateless system monitoring tool that provides CPU, GPU, memory, and
network statistics. Designed for integration with DankMaterialShell but can be
used standalone. This package always includes the latest stable dgop release.
%prep %prep
%setup -q -c -n dms-qml %setup -q -c -n dms-qml
@@ -581,10 +576,18 @@ jobs:
gunzip -c %{_builddir}/dms-cli.gz > %{_builddir}/dms-cli gunzip -c %{_builddir}/dms-cli.gz > %{_builddir}/dms-cli
chmod +x %{_builddir}/dms-cli chmod +x %{_builddir}/dms-cli
wget -O %{_builddir}/dgop.gz "https://github.com/AvengeMedia/dgop/releases/latest/download/dgop-linux-${ARCH_SUFFIX}.gz" || {
echo "Failed to download dgop for architecture %{_arch}"
exit 1
}
gunzip -c %{_builddir}/dgop.gz > %{_builddir}/dgop
chmod +x %{_builddir}/dgop
%build %build
%install %install
install -Dm755 %{_builddir}/dms-cli %{buildroot}%{_bindir}/dms install -Dm755 %{_builddir}/dms-cli %{buildroot}%{_bindir}/dms
install -Dm755 %{_builddir}/dgop %{buildroot}%{_bindir}/dgop
install -d %{buildroot}%{_datadir}/bash-completion/completions install -d %{buildroot}%{_datadir}/bash-completion/completions
install -d %{buildroot}%{_datadir}/zsh/site-functions install -d %{buildroot}%{_datadir}/zsh/site-functions
@@ -614,8 +617,10 @@ jobs:
rmdir "%{_sysconfdir}/xdg/quickshell" 2>/dev/null || true rmdir "%{_sysconfdir}/xdg/quickshell" 2>/dev/null || true
rmdir "%{_sysconfdir}/xdg" 2>/dev/null || true rmdir "%{_sysconfdir}/xdg" 2>/dev/null || true
fi fi
# Signal running DMS instances to reload
pkill -USR1 -x dms >/dev/null 2>&1 || : if [ "$1" -ge 2 ]; then
pkill -USR1 -x dms >/dev/null 2>&1 || true
fi
%files %files
%license LICENSE %license LICENSE
@@ -631,10 +636,14 @@ jobs:
%{_datadir}/zsh/site-functions/_dms %{_datadir}/zsh/site-functions/_dms
%{_datadir}/fish/vendor_completions.d/dms.fish %{_datadir}/fish/vendor_completions.d/dms.fish
%files -n dgop
%{_bindir}/dgop
%changelog %changelog
* CHANGELOG_DATE_PLACEHOLDER AvengeMedia <contact@avengemedia.com> - VERSION_PLACEHOLDER-1 * CHANGELOG_DATE_PLACEHOLDER AvengeMedia <contact@avengemedia.com> - VERSION_PLACEHOLDER-1
- Stable release VERSION_PLACEHOLDER - Stable release VERSION_PLACEHOLDER
- Built from GitHub release - Built from GitHub release
- Includes latest dms-cli and dgop binaries
SPECEOF SPECEOF
sed -i "s/VERSION_PLACEHOLDER/${VERSION}/g" ~/rpmbuild/SPECS/dms.spec sed -i "s/VERSION_PLACEHOLDER/${VERSION}/g" ~/rpmbuild/SPECS/dms.spec

View File

@@ -62,7 +62,7 @@ jobs:
} }
echo "✅ Source downloaded" echo "✅ Source downloaded"
echo "Note: dms-cli binary will be downloaded during build based on target architecture" echo "Note: dms-cli and dgop binaries will be downloaded during build based on target architecture"
ls -lh ls -lh
- name: Generate stable spec file - name: Generate stable spec file
@@ -94,7 +94,7 @@ jobs:
Requires: (quickshell or quickshell-git) Requires: (quickshell or quickshell-git)
Requires: accountsservice Requires: accountsservice
Requires: dms-cli = %{version}-%{release} Requires: dms-cli
Requires: dgop Requires: dgop
Recommends: cava Recommends: cava
@@ -125,6 +125,17 @@ jobs:
Command-line interface for DankMaterialShell configuration and management. Command-line interface for DankMaterialShell configuration and management.
Provides native DBus bindings, NetworkManager integration, and system utilities. Provides native DBus bindings, NetworkManager integration, and system utilities.
%package -n dgop
Summary: Stateless CPU/GPU monitor for DankMaterialShell
License: MIT
URL: https://github.com/AvengeMedia/dgop
Provides: dgop
%description -n dgop
DGOP is a stateless system monitoring tool that provides CPU, GPU, memory, and
network statistics. Designed for integration with DankMaterialShell but can be
used standalone. This package always includes the latest stable dgop release.
%prep %prep
%setup -q -c -n dms-qml %setup -q -c -n dms-qml
@@ -151,10 +162,19 @@ jobs:
gunzip -c %{_builddir}/dms-cli.gz > %{_builddir}/dms-cli gunzip -c %{_builddir}/dms-cli.gz > %{_builddir}/dms-cli
chmod +x %{_builddir}/dms-cli chmod +x %{_builddir}/dms-cli
# Download dgop for target architecture
wget -O %{_builddir}/dgop.gz "https://github.com/AvengeMedia/dgop/releases/latest/download/dgop-linux-${ARCH_SUFFIX}.gz" || {
echo "Failed to download dgop for architecture %{_arch}"
exit 1
}
gunzip -c %{_builddir}/dgop.gz > %{_builddir}/dgop
chmod +x %{_builddir}/dgop
%build %build
%install %install
install -Dm755 %{_builddir}/dms-cli %{buildroot}%{_bindir}/dms install -Dm755 %{_builddir}/dms-cli %{buildroot}%{_bindir}/dms
install -Dm755 %{_builddir}/dgop %{buildroot}%{_bindir}/dgop
# Shell completions # Shell completions
install -d %{buildroot}%{_datadir}/bash-completion/completions install -d %{buildroot}%{_datadir}/bash-completion/completions
@@ -182,8 +202,11 @@ jobs:
rmdir "%{_sysconfdir}/xdg/quickshell" 2>/dev/null || true rmdir "%{_sysconfdir}/xdg/quickshell" 2>/dev/null || true
rmdir "%{_sysconfdir}/xdg" 2>/dev/null || true rmdir "%{_sysconfdir}/xdg" 2>/dev/null || true
fi fi
# Signal running DMS instances to reload (harmless if none running)
pkill -USR1 -x dms >/dev/null 2>&1 || : # Restart DMS for active users after upgrade
if [ "$1" -ge 2 ]; then
pkill -USR1 -x dms >/dev/null 2>&1 || true
fi
%files %files
%license LICENSE %license LICENSE
@@ -197,10 +220,14 @@ jobs:
%{_datadir}/zsh/site-functions/_dms %{_datadir}/zsh/site-functions/_dms
%{_datadir}/fish/vendor_completions.d/dms.fish %{_datadir}/fish/vendor_completions.d/dms.fish
%files -n dgop
%{_bindir}/dgop
%changelog %changelog
* CHANGELOG_DATE_PLACEHOLDER AvengeMedia <contact@avengemedia.com> - VERSION_PLACEHOLDER-RELEASE_PLACEHOLDER * CHANGELOG_DATE_PLACEHOLDER AvengeMedia <contact@avengemedia.com> - VERSION_PLACEHOLDER-RELEASE_PLACEHOLDER
- Stable release VERSION_PLACEHOLDER - Stable release VERSION_PLACEHOLDER
- Built from GitHub release - Built from GitHub release
- Includes latest dms-cli and dgop binaries
SPECEOF SPECEOF
sed -i "s/VERSION_PLACEHOLDER/${VERSION}/g" ~/rpmbuild/SPECS/dms.spec sed -i "s/VERSION_PLACEHOLDER/${VERSION}/g" ~/rpmbuild/SPECS/dms.spec

View File

@@ -5,21 +5,21 @@
<img src="assets/danklogo.svg" alt="DankMaterialShell" width="200"> <img src="assets/danklogo.svg" alt="DankMaterialShell" width="200">
</a> </a>
### A modern desktop shell for Wayland ### A modern desktop shell for Wayland
Built with [Quickshell](https://quickshell.org/) and [Go](https://go.dev/) Built with [Quickshell](https://quickshell.org/) and [Go](https://go.dev/)
[![Documentation](https://img.shields.io/badge/docs-danklinux.com-9ccbfb?style=for-the-badge&labelColor=101418)](https://danklinux.com/docs) [![Documentation](https://img.shields.io/badge/docs-danklinux.com-9ccbfb?style=for-the-badge&labelColor=101418)](https://danklinux.com/docs)
[![GitHub stars](https://img.shields.io/github/stars/AvengeMedia/DankMaterialShell?style=for-the-badge&labelColor=101418&color=ffd700)](https://github.com/AvengeMedia/DankMaterialShell/stargazers) [![GitHub stars](https://img.shields.io/github/stars/AvengeMedia/DankMaterialShell?style=for-the-badge&labelColor=101418&color=ffd700)](https://github.com/AvengeMedia/DankMaterialShell/stargazers)
[![GitHub License](https://img.shields.io/github/license/AvengeMedia/DankMaterialShell?style=for-the-badge&labelColor=101418&color=b9c8da)](https://github.com/AvengeMedia/DankMaterialShell/blob/master/LICENSE) [![GitHub License](https://img.shields.io/github/license/AvengeMedia/DankMaterialShell?style=for-the-badge&labelColor=101418&color=b9c8da)](https://github.com/AvengeMedia/DankMaterialShell/blob/master/LICENSE)
[![GitHub release](https://img.shields.io/github/v/release/AvengeMedia/DankMaterialShell?style=for-the-badge&labelColor=101418&color=9ccbfb)](https://github.com/AvengeMedia/DankMaterialShell/releases) [![GitHub release](https://img.shields.io/github/v/release/AvengeMedia/DankMaterialShell?style=for-the-badge&labelColor=101418&color=9ccbfb)](https://github.com/AvengeMedia/DankMaterialShell/releases)
[![AUR version](https://img.shields.io/aur/version/dms-shell-bin?style=for-the-badge&labelColor=101418&color=9ccbfb)](https://aur.archlinux.org/packages/dms-shell-bin) [![AUR version](https://img.shields.io/aur/version/dms-shell-bin?style=for-the-badge&labelColor=101418&color=9ccbfb)](https://aur.archlinux.org/packages/dms-shell-bin)
[![AUR version (git)](<https://img.shields.io/aur/version/dms-shell-git?style=for-the-badge&labelColor=101418&color=9ccbfb&label=AUR%20(git)>)](https://aur.archlinux.org/packages/dms-shell-git) [![AUR version (git)](https://img.shields.io/aur/version/dms-shell-git?style=for-the-badge&labelColor=101418&color=9ccbfb&label=AUR%20(git))](https://aur.archlinux.org/packages/dms-shell-git)
[![Ko-Fi donate](https://img.shields.io/badge/donate-kofi?style=for-the-badge&logo=ko-fi&logoColor=ffffff&label=ko-fi&labelColor=101418&color=f16061&link=https%3A%2F%2Fko-fi.com%2Fdanklinux)](https://ko-fi.com/danklinux) [![Ko-Fi donate](https://img.shields.io/badge/donate-kofi?style=for-the-badge&logo=ko-fi&logoColor=ffffff&label=ko-fi&labelColor=101418&color=f16061&link=https%3A%2F%2Fko-fi.com%2Fdanklinux)](https://ko-fi.com/danklinux)
</div> </div>
DankMaterialShell is a complete desktop shell for [niri](https://github.com/YaLTeR/niri), [Hyprland](https://hyprland.org/), [MangoWC](https://github.com/DreamMaoMao/mangowc), [Sway](https://swaywm.org), [labwc](https://labwc.github.io/), [Scroll](https://github.com/dawsers/scroll), and other Wayland compositors. It replaces waybar, swaylock, swayidle, mako, fuzzel, polkit, and everything else you'd normally stitch together to make a desktop. DankMaterialShell is a complete desktop shell for [niri](https://github.com/YaLTeR/niri), [Hyprland](https://hyprland.org/), [MangoWC](https://github.com/DreamMaoMao/mangowc), [Sway](https://swaywm.org), [labwc](https://labwc.github.io/), and other Wayland compositors. It replaces waybar, swaylock, swayidle, mako, fuzzel, polkit, and everything else you'd normally stitch together to make a desktop.
## Repository Structure ## Repository Structure
@@ -105,7 +105,7 @@ Extend functionality with the [plugin registry](https://plugins.danklinux.com).
## Supported Compositors ## Supported Compositors
Works best with [niri](https://github.com/YaLTeR/niri), [Hyprland](https://hyprland.org/), [Sway](https://swaywm.org/), [MangoWC](https://github.com/DreamMaoMao/mangowc), [labwc](https://labwc.github.io/), and [Scroll](https://github.com/dawsers/scroll) with full workspace switching, overview integration, and monitor management. Other Wayland compositors work with reduced features. Works best with [niri](https://github.com/YaLTeR/niri), [Hyprland](https://hyprland.org/), [Sway](https://swaywm.org/), [MangoWC](https://github.com/DreamMaoMao/mangowc), and [labwc](https://labwc.github.io/) with full workspace switching, overview integration, and monitor management. Other Wayland compositors work with reduced features.
[Compositor configuration guide](https://danklinux.com/docs/dankmaterialshell/compositors) [Compositor configuration guide](https://danklinux.com/docs/dankmaterialshell/compositors)
@@ -127,7 +127,7 @@ dms plugins search # Browse plugin registry
## Documentation ## Documentation
- **Website:** [danklinux.com](https://danklinux.com) - **Website:** [danklinux.com](https://danklinux.com)
- **Docs:** [danklinux.com/docs](https://danklinux.com/docs/) - **Docs:** [danklinux.com/docs](https://danklinux.com/docs)
- **Theming:** [Application themes](https://danklinux.com/docs/dankmaterialshell/application-themes) | [Custom themes](https://danklinux.com/docs/dankmaterialshell/custom-themes) - **Theming:** [Application themes](https://danklinux.com/docs/dankmaterialshell/application-themes) | [Custom themes](https://danklinux.com/docs/dankmaterialshell/custom-themes)
- **Plugins:** [Development guide](https://danklinux.com/docs/dankmaterialshell/plugins-overview) - **Plugins:** [Development guide](https://danklinux.com/docs/dankmaterialshell/plugins-overview)
- **Support:** [Ko-fi](https://ko-fi.com/avengemediallc) - **Support:** [Ko-fi](https://ko-fi.com/avengemediallc)
@@ -143,7 +143,6 @@ See component-specific documentation:
### Building from Source ### Building from Source
**Core + Dankinstall:** **Core + Dankinstall:**
```bash ```bash
cd core cd core
make # Build dms CLI make # Build dms CLI
@@ -151,13 +150,11 @@ make dankinstall # Build installer
``` ```
**Shell:** **Shell:**
```bash ```bash
quickshell -p quickshell/ quickshell -p quickshell/
``` ```
**NixOS:** **NixOS:**
```nix ```nix
{ {
inputs.dms.url = "github:AvengeMedia/DankMaterialShell"; inputs.dms.url = "github:AvengeMedia/DankMaterialShell";

1
alejandra.toml Normal file
View File

@@ -0,0 +1 @@
indentation = "FourSpaces"

View File

@@ -22,8 +22,6 @@ linters:
- (*os.Process).Signal - (*os.Process).Signal
- (*os.Process).Kill - (*os.Process).Kill
- syscall.Kill - syscall.Kill
# Seek on memfd (reset position before passing fd)
- syscall.Seek
# DBus cleanup # DBus cleanup
- (*github.com/godbus/dbus/v5.Conn).RemoveMatchSignal - (*github.com/godbus/dbus/v5.Conn).RemoveMatchSignal
- (*github.com/godbus/dbus/v5.Conn).RemoveSignal - (*github.com/godbus/dbus/v5.Conn).RemoveSignal

View File

@@ -12,11 +12,6 @@ import (
var Version = "dev" var Version = "dev"
func main() { func main() {
if os.Getuid() == 0 {
fmt.Fprintln(os.Stderr, "Error: dankinstall must not be run as root")
os.Exit(1)
}
fileLogger, err := log.NewFileLogger() fileLogger, err := log.NewFileLogger()
if err != nil { if err != nil {
fmt.Printf("Warning: Failed to create log file: %v\n", err) fmt.Printf("Warning: Failed to create log file: %v\n", err)

View File

@@ -211,13 +211,45 @@ func runBrightnessSet(cmd *cobra.Command, args []string) {
exponential, _ := cmd.Flags().GetBool("exponential") exponential, _ := cmd.Flags().GetBool("exponential")
exponent, _ := cmd.Flags().GetFloat64("exponent") exponent, _ := cmd.Flags().GetFloat64("exponent")
// For backlight/leds devices, try logind backend first (requires D-Bus connection)
parts := strings.SplitN(deviceID, ":", 2) parts := strings.SplitN(deviceID, ":", 2)
if len(parts) == 2 && (parts[0] == "backlight" || parts[0] == "leds") { if len(parts) == 2 && (parts[0] == "backlight" || parts[0] == "leds") {
if ok := tryLogindBrightness(parts[0], parts[1], deviceID, percent, exponential, exponent); ok { subsystem := parts[0]
return name := parts[1]
// Initialize backends needed for logind approach
sysfs, err := brightness.NewSysfsBackend()
if err != nil {
log.Debugf("NewSysfsBackend failed: %v", err)
} else {
logind, err := brightness.NewLogindBackend()
if err != nil {
log.Debugf("NewLogindBackend failed: %v", err)
} else {
defer logind.Close()
// Get device info to convert percent to value
dev, err := sysfs.GetDevice(deviceID)
if err == nil {
// Calculate hardware value using the same logic as Manager.setViaSysfsWithLogind
value := sysfs.PercentToValueWithExponent(percent, dev, exponential, exponent)
// Call logind with hardware value
if err := logind.SetBrightness(subsystem, name, uint32(value)); err == nil {
log.Debugf("set %s to %d%% (%d) via logind", deviceID, percent, value)
fmt.Printf("Set %s to %d%%\n", deviceID, percent)
return
} else {
log.Debugf("logind.SetBrightness failed: %v", err)
}
} else {
log.Debugf("sysfs.GetDeviceByID failed: %v", err)
}
}
} }
} }
// Fallback to direct sysfs (requires write permissions)
sysfs, err := brightness.NewSysfsBackend() sysfs, err := brightness.NewSysfsBackend()
if err == nil { if err == nil {
if err := sysfs.SetBrightnessWithExponent(deviceID, percent, exponential, exponent); err == nil { if err := sysfs.SetBrightnessWithExponent(deviceID, percent, exponential, exponent); err == nil {
@@ -248,37 +280,6 @@ func runBrightnessSet(cmd *cobra.Command, args []string) {
log.Fatalf("Failed to set brightness for device: %s", deviceID) log.Fatalf("Failed to set brightness for device: %s", deviceID)
} }
func tryLogindBrightness(subsystem, name, deviceID string, percent int, exponential bool, exponent float64) bool {
sysfs, err := brightness.NewSysfsBackend()
if err != nil {
log.Debugf("NewSysfsBackend failed: %v", err)
return false
}
logind, err := brightness.NewLogindBackend()
if err != nil {
log.Debugf("NewLogindBackend failed: %v", err)
return false
}
defer logind.Close()
dev, err := sysfs.GetDevice(deviceID)
if err != nil {
log.Debugf("sysfs.GetDeviceByID failed: %v", err)
return false
}
value := sysfs.PercentToValueWithExponent(percent, dev, exponential, exponent)
if err := logind.SetBrightness(subsystem, name, uint32(value)); err != nil {
log.Debugf("logind.SetBrightness failed: %v", err)
return false
}
log.Debugf("set %s to %d%% (%d) via logind", deviceID, percent, value)
fmt.Printf("Set %s to %d%%\n", deviceID, percent)
return true
}
func getBrightnessDevices(includeDDC bool) []string { func getBrightnessDevices(includeDDC bool) []string {
allDevices := getAllBrightnessDevices(includeDDC) allDevices := getAllBrightnessDevices(includeDDC)

View File

@@ -152,24 +152,6 @@ var pluginsUninstallCmd = &cobra.Command{
}, },
} }
var pluginsUpdateCmd = &cobra.Command{
Use: "update <plugin-id>",
Short: "Update a plugin by ID",
Long: "Update an installed DMS plugin using its ID (e.g., 'myPlugin'). Plugin names are also supported.",
Args: cobra.ExactArgs(1),
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
if len(args) != 0 {
return nil, cobra.ShellCompDirectiveNoFileComp
}
return getInstalledPluginIDs(), cobra.ShellCompDirectiveNoFileComp
},
Run: func(cmd *cobra.Command, args []string) {
if err := updatePluginCLI(args[0]); err != nil {
log.Fatalf("Error updating plugin: %v", err)
}
},
}
func runVersion(cmd *cobra.Command, args []string) { func runVersion(cmd *cobra.Command, args []string) {
printASCII() printASCII()
fmt.Printf("%s\n", formatVersion(Version)) fmt.Printf("%s\n", formatVersion(Version))
@@ -426,70 +408,49 @@ func uninstallPluginCLI(idOrName string) error {
return fmt.Errorf("failed to create registry: %w", err) return fmt.Errorf("failed to create registry: %w", err)
} }
pluginList, _ := registry.List() pluginList, err := registry.List()
plugin := plugins.FindByIDOrName(idOrName, pluginList)
if plugin != nil {
installed, err := manager.IsInstalled(*plugin)
if err != nil {
return fmt.Errorf("failed to check install status: %w", err)
}
if !installed {
return fmt.Errorf("plugin not installed: %s", plugin.Name)
}
fmt.Printf("Uninstalling plugin: %s (ID: %s)\n", plugin.Name, plugin.ID)
if err := manager.Uninstall(*plugin); err != nil {
return fmt.Errorf("failed to uninstall plugin: %w", err)
}
fmt.Printf("Plugin uninstalled successfully: %s\n", plugin.Name)
return nil
}
fmt.Printf("Uninstalling plugin: %s\n", idOrName)
if err := manager.UninstallByIDOrName(idOrName); err != nil {
return err
}
fmt.Printf("Plugin uninstalled successfully: %s\n", idOrName)
return nil
}
func updatePluginCLI(idOrName string) error {
manager, err := plugins.NewManager()
if err != nil { if err != nil {
return fmt.Errorf("failed to create manager: %w", err) return fmt.Errorf("failed to list plugins: %w", err)
} }
registry, err := plugins.NewRegistry() // First, try to find by ID (preferred method)
var plugin *plugins.Plugin
for _, p := range pluginList {
if p.ID == idOrName {
plugin = &p
break
}
}
// Fallback to name for backward compatibility
if plugin == nil {
for _, p := range pluginList {
if p.Name == idOrName {
plugin = &p
break
}
}
}
if plugin == nil {
return fmt.Errorf("plugin not found: %s", idOrName)
}
installed, err := manager.IsInstalled(*plugin)
if err != nil { if err != nil {
return fmt.Errorf("failed to create registry: %w", err) return fmt.Errorf("failed to check install status: %w", err)
} }
pluginList, _ := registry.List() if !installed {
plugin := plugins.FindByIDOrName(idOrName, pluginList) return fmt.Errorf("plugin not installed: %s", plugin.Name)
if plugin != nil {
installed, err := manager.IsInstalled(*plugin)
if err != nil {
return fmt.Errorf("failed to check install status: %w", err)
}
if !installed {
return fmt.Errorf("plugin not installed: %s", plugin.Name)
}
fmt.Printf("Updating plugin: %s (ID: %s)\n", plugin.Name, plugin.ID)
if err := manager.Update(*plugin); err != nil {
return fmt.Errorf("failed to update plugin: %w", err)
}
fmt.Printf("Plugin updated successfully: %s\n", plugin.Name)
return nil
} }
fmt.Printf("Updating plugin: %s\n", idOrName) fmt.Printf("Uninstalling plugin: %s (ID: %s)\n", plugin.Name, plugin.ID)
if err := manager.UpdateByIDOrName(idOrName); err != nil { if err := manager.Uninstall(*plugin); err != nil {
return err return fmt.Errorf("failed to uninstall plugin: %w", err)
} }
fmt.Printf("Plugin updated successfully: %s\n", idOrName)
fmt.Printf("Plugin uninstalled successfully: %s\n", plugin.Name)
return nil return nil
} }

View File

@@ -15,7 +15,6 @@ import (
"github.com/AvengeMedia/DankMaterialShell/core/internal/distros" "github.com/AvengeMedia/DankMaterialShell/core/internal/distros"
"github.com/AvengeMedia/DankMaterialShell/core/internal/errdefs" "github.com/AvengeMedia/DankMaterialShell/core/internal/errdefs"
"github.com/AvengeMedia/DankMaterialShell/core/internal/log" "github.com/AvengeMedia/DankMaterialShell/core/internal/log"
"github.com/AvengeMedia/DankMaterialShell/core/internal/utils"
"github.com/AvengeMedia/DankMaterialShell/core/internal/version" "github.com/AvengeMedia/DankMaterialShell/core/internal/version"
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
@@ -122,10 +121,10 @@ func updateArchLinux() error {
var helper string var helper string
var updateCmd *exec.Cmd var updateCmd *exec.Cmd
if utils.CommandExists("yay") { if commandExists("yay") {
helper = "yay" helper = "yay"
updateCmd = exec.Command("yay", "-S", packageName) updateCmd = exec.Command("yay", "-S", packageName)
} else if utils.CommandExists("paru") { } else if commandExists("paru") {
helper = "paru" helper = "paru"
updateCmd = exec.Command("paru", "-S", packageName) updateCmd = exec.Command("paru", "-S", packageName)
} else { } else {

View File

@@ -10,7 +10,6 @@ import (
"github.com/AvengeMedia/DankMaterialShell/core/internal/greeter" "github.com/AvengeMedia/DankMaterialShell/core/internal/greeter"
"github.com/AvengeMedia/DankMaterialShell/core/internal/log" "github.com/AvengeMedia/DankMaterialShell/core/internal/log"
"github.com/AvengeMedia/DankMaterialShell/core/internal/utils"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"golang.org/x/text/cases" "golang.org/x/text/cases"
"golang.org/x/text/language" "golang.org/x/text/language"
@@ -449,7 +448,7 @@ func enableGreeter() error {
fmt.Println("Detecting installed compositors...") fmt.Println("Detecting installed compositors...")
compositors := greeter.DetectCompositors() compositors := greeter.DetectCompositors()
if utils.CommandExists("sway") { if commandExists("sway") {
compositors = append(compositors, "sway") compositors = append(compositors, "sway")
} }

View File

@@ -89,11 +89,6 @@ func initializeProviders() {
log.Warnf("Failed to register MangoWC provider: %v", err) log.Warnf("Failed to register MangoWC provider: %v", err)
} }
scrollProvider := providers.NewSwayProvider("$HOME/.config/scroll")
if err := registry.Register(scrollProvider); err != nil {
log.Warnf("Failed to register Scroll provider: %v", err)
}
swayProvider := providers.NewSwayProvider("$HOME/.config/sway") swayProvider := providers.NewSwayProvider("$HOME/.config/sway")
if err := registry.Register(swayProvider); err != nil { if err := registry.Register(swayProvider); err != nil {
log.Warnf("Failed to register Sway provider: %v", err) log.Warnf("Failed to register Sway provider: %v", err)
@@ -130,8 +125,6 @@ func makeProviderWithPath(name, path string) keybinds.Provider {
return providers.NewMangoWCProvider(path) return providers.NewMangoWCProvider(path)
case "sway": case "sway":
return providers.NewSwayProvider(path) return providers.NewSwayProvider(path)
case "scroll":
return providers.NewSwayProvider(path)
case "niri": case "niri":
return providers.NewNiriProvider(path) return providers.NewNiriProvider(path)
default: default:

View File

@@ -295,14 +295,7 @@ func bufferToRGBThumbnail(buf *screenshot.ShmBuffer, maxSize int, pixelFormat ui
data := buf.Data() data := buf.Data()
rgb := make([]byte, dstW*dstH*3) rgb := make([]byte, dstW*dstH*3)
swapRB := pixelFormat == uint32(screenshot.FormatARGB8888) || pixelFormat == uint32(screenshot.FormatXRGB8888) || pixelFormat == 0
var swapRB bool
switch pixelFormat {
case uint32(screenshot.FormatABGR8888), uint32(screenshot.FormatXBGR8888):
swapRB = false
default:
swapRB = true
}
for y := 0; y < dstH; y++ { for y := 0; y < dstH; y++ {
srcY := int(float64(y) / scale) srcY := int(float64(y) / scale)
@@ -316,17 +309,16 @@ func bufferToRGBThumbnail(buf *screenshot.ShmBuffer, maxSize int, pixelFormat ui
} }
si := srcY*buf.Stride + srcX*4 si := srcY*buf.Stride + srcX*4
di := (y*dstW + x) * 3 di := (y*dstW + x) * 3
if si+3 >= len(data) { if si+2 < len(data) {
continue if swapRB {
} rgb[di+0] = data[si+2]
if swapRB { rgb[di+1] = data[si+1]
rgb[di+0] = data[si+2] rgb[di+2] = data[si+0]
rgb[di+1] = data[si+1] } else {
rgb[di+2] = data[si+0] rgb[di+0] = data[si+0]
} else { rgb[di+1] = data[si+1]
rgb[di+0] = data[si+0] rgb[di+2] = data[si+2]
rgb[di+1] = data[si+1] }
rgb[di+2] = data[si+2]
} }
} }
} }
@@ -378,37 +370,7 @@ func runScreenshotList(cmd *cobra.Command, args []string) {
} }
for _, o := range outputs { for _, o := range outputs {
scaleStr := fmt.Sprintf("%.2f", o.FractionalScale) fmt.Printf("%s: %dx%d+%d+%d (scale: %d)\n",
if o.FractionalScale == float64(int(o.FractionalScale)) { o.Name, o.Width, o.Height, o.X, o.Y, o.Scale)
scaleStr = fmt.Sprintf("%d", int(o.FractionalScale))
}
transformStr := transformName(o.Transform)
fmt.Printf("%s: %dx%d+%d+%d scale=%s transform=%s\n",
o.Name, o.Width, o.Height, o.X, o.Y, scaleStr, transformStr)
}
}
func transformName(t int32) string {
switch t {
case 0:
return "normal"
case 1:
return "90"
case 2:
return "180"
case 3:
return "270"
case 4:
return "flipped"
case 5:
return "flipped-90"
case 6:
return "flipped-180"
case 7:
return "flipped-270"
default:
return fmt.Sprintf("%d", t)
} }
} }

View File

@@ -29,7 +29,6 @@ func runSetup() error {
wm, wmSelected := promptCompositor() wm, wmSelected := promptCompositor()
terminal, terminalSelected := promptTerminal() terminal, terminalSelected := promptTerminal()
useSystemd := promptSystemd()
if !wmSelected && !terminalSelected { if !wmSelected && !terminalSelected {
fmt.Println("No configurations selected. Exiting.") fmt.Println("No configurations selected. Exiting.")
@@ -68,14 +67,14 @@ func runSetup() error {
var err error var err error
if wmSelected && terminalSelected { if wmSelected && terminalSelected {
results, err = deployer.DeployConfigurationsWithSystemd(ctx, wm, terminal, useSystemd) results, err = deployer.DeployConfigurationsWithTerminal(ctx, wm, terminal)
} else if wmSelected { } else if wmSelected {
results, err = deployer.DeployConfigurationsWithSystemd(ctx, wm, deps.TerminalGhostty, useSystemd) results, err = deployer.DeployConfigurationsWithTerminal(ctx, wm, deps.TerminalGhostty)
if len(results) > 1 { if len(results) > 1 {
results = results[:1] results = results[:1]
} }
} else if terminalSelected { } else if terminalSelected {
results, err = deployer.DeployConfigurationsWithSystemd(ctx, deps.WindowManagerNiri, terminal, useSystemd) results, err = deployer.DeployConfigurationsWithTerminal(ctx, deps.WindowManagerNiri, terminal)
if len(results) > 0 && results[0].ConfigType == "Niri" { if len(results) > 0 && results[0].ConfigType == "Niri" {
results = results[1:] results = results[1:]
} }
@@ -145,19 +144,6 @@ func promptTerminal() (deps.Terminal, bool) {
} }
} }
func promptSystemd() bool {
fmt.Println("\nUse systemd for session management?")
fmt.Println("1) Yes (recommended for most distros)")
fmt.Println("2) No (standalone, no systemd integration)")
var response string
fmt.Print("\nChoice (1-2): ")
fmt.Scanln(&response)
response = strings.TrimSpace(response)
return response != "2"
}
func checkExistingConfigs(wm deps.WindowManager, wmSelected bool, terminal deps.Terminal, terminalSelected bool) bool { func checkExistingConfigs(wm deps.WindowManager, wmSelected bool, terminal deps.Terminal, terminalSelected bool) bool {
homeDir := os.Getenv("HOME") homeDir := os.Getenv("HOME")
willBackup := false willBackup := false

View File

@@ -23,7 +23,7 @@ func init() {
updateCmd.AddCommand(updateCheckCmd) updateCmd.AddCommand(updateCheckCmd)
// Add subcommands to plugins // Add subcommands to plugins
pluginsCmd.AddCommand(pluginsBrowseCmd, pluginsListCmd, pluginsInstallCmd, pluginsUninstallCmd, pluginsUpdateCmd) pluginsCmd.AddCommand(pluginsBrowseCmd, pluginsListCmd, pluginsInstallCmd, pluginsUninstallCmd)
// Add common commands to root // Add common commands to root
rootCmd.AddCommand(getCommonCommands()...) rootCmd.AddCommand(getCommonCommands()...)

View File

@@ -21,7 +21,7 @@ func init() {
greeterCmd.AddCommand(greeterSyncCmd, greeterEnableCmd, greeterStatusCmd) greeterCmd.AddCommand(greeterSyncCmd, greeterEnableCmd, greeterStatusCmd)
// Add subcommands to plugins // Add subcommands to plugins
pluginsCmd.AddCommand(pluginsBrowseCmd, pluginsListCmd, pluginsInstallCmd, pluginsUninstallCmd, pluginsUpdateCmd) pluginsCmd.AddCommand(pluginsBrowseCmd, pluginsListCmd, pluginsInstallCmd, pluginsUninstallCmd)
// Add common commands to root // Add common commands to root
rootCmd.AddCommand(getCommonCommands()...) rootCmd.AddCommand(getCommonCommands()...)

View File

@@ -104,6 +104,7 @@ func getAllDMSPIDs() []int {
continue continue
} }
// Check if the child process is still alive
proc, err := os.FindProcess(childPID) proc, err := os.FindProcess(childPID)
if err != nil { if err != nil {
os.Remove(pidFile) os.Remove(pidFile)
@@ -111,15 +112,18 @@ func getAllDMSPIDs() []int {
} }
if err := proc.Signal(syscall.Signal(0)); err != nil { if err := proc.Signal(syscall.Signal(0)); err != nil {
// Process is dead, remove stale PID file
os.Remove(pidFile) os.Remove(pidFile)
continue continue
} }
pids = append(pids, childPID) pids = append(pids, childPID)
// Also get the parent PID from the filename
parentPIDStr := strings.TrimPrefix(entry.Name(), "danklinux-") parentPIDStr := strings.TrimPrefix(entry.Name(), "danklinux-")
parentPIDStr = strings.TrimSuffix(parentPIDStr, ".pid") parentPIDStr = strings.TrimSuffix(parentPIDStr, ".pid")
if parentPID, err := strconv.Atoi(parentPIDStr); err == nil { if parentPID, err := strconv.Atoi(parentPIDStr); err == nil {
// Check if parent is still alive
if parentProc, err := os.FindProcess(parentPID); err == nil { if parentProc, err := os.FindProcess(parentPID); err == nil {
if err := parentProc.Signal(syscall.Signal(0)); err == nil { if err := parentProc.Signal(syscall.Signal(0)); err == nil {
pids = append(pids, parentPID) pids = append(pids, parentPID)
@@ -155,7 +159,6 @@ func runShellInteractive(session bool) {
errChan <- fmt.Errorf("server panic: %v", r) errChan <- fmt.Errorf("server panic: %v", r)
} }
}() }()
server.CLIVersion = Version
if err := server.Start(false); err != nil { if err := server.Start(false); err != nil {
errChan <- fmt.Errorf("server error: %w", err) errChan <- fmt.Errorf("server error: %w", err)
} }
@@ -222,6 +225,7 @@ func runShellInteractive(session bool) {
return return
} }
// All other signals: clean shutdown
log.Infof("\nReceived signal %v, shutting down...", sig) log.Infof("\nReceived signal %v, shutting down...", sig)
cancel() cancel()
cmd.Process.Signal(syscall.SIGTERM) cmd.Process.Signal(syscall.SIGTERM)
@@ -278,6 +282,7 @@ func restartShell() {
} }
func killShell() { func killShell() {
// Get all tracked DMS PIDs from PID files
pids := getAllDMSPIDs() pids := getAllDMSPIDs()
if len(pids) == 0 { if len(pids) == 0 {
@@ -288,12 +293,14 @@ func killShell() {
currentPid := os.Getpid() currentPid := os.Getpid()
uniquePids := make(map[int]bool) uniquePids := make(map[int]bool)
// Deduplicate and filter out current process
for _, pid := range pids { for _, pid := range pids {
if pid != currentPid { if pid != currentPid {
uniquePids[pid] = true uniquePids[pid] = true
} }
} }
// Kill all tracked processes
for pid := range uniquePids { for pid := range uniquePids {
proc, err := os.FindProcess(pid) proc, err := os.FindProcess(pid)
if err != nil { if err != nil {
@@ -301,6 +308,7 @@ func killShell() {
continue continue
} }
// Check if process is still alive before killing
if err := proc.Signal(syscall.Signal(0)); err != nil { if err := proc.Signal(syscall.Signal(0)); err != nil {
continue continue
} }
@@ -312,6 +320,7 @@ func killShell() {
} }
} }
// Clean up any remaining PID files
dir := getRuntimeDir() dir := getRuntimeDir()
entries, err := os.ReadDir(dir) entries, err := os.ReadDir(dir)
if err != nil { if err != nil {
@@ -328,6 +337,7 @@ func killShell() {
func runShellDaemon(session bool) { func runShellDaemon(session bool) {
isSessionManaged = session isSessionManaged = session
// Check if this is the daemon child process by looking for the hidden flag
isDaemonChild := false isDaemonChild := false
for _, arg := range os.Args { for _, arg := range os.Args {
if arg == "--daemon-child" { if arg == "--daemon-child" {

View File

@@ -6,6 +6,12 @@ import (
"strings" "strings"
) )
func commandExists(cmd string) bool {
_, err := exec.LookPath(cmd)
return err == nil
}
// findCommandPath returns the absolute path to a command in PATH
func findCommandPath(cmd string) (string, error) { func findCommandPath(cmd string) (string, error) {
path, err := exec.LookPath(cmd) path, err := exec.LookPath(cmd)
if err != nil { if err != nil {

View File

@@ -30,7 +30,6 @@ type Output struct {
height int32 height int32
scale int32 scale int32
fractionalScale float64 fractionalScale float64
transform int32
} }
type LayerSurface struct { type LayerSurface struct {
@@ -277,7 +276,6 @@ func (p *Picker) setupOutputHandlers(name uint32, output *client.Output) {
if o, ok := p.outputs[name]; ok { if o, ok := p.outputs[name]; ok {
o.x = e.X o.x = e.X
o.y = e.Y o.y = e.Y
o.transform = int32(e.Transform)
} }
p.outputsMu.Unlock() p.outputsMu.Unlock()
}) })
@@ -487,19 +485,8 @@ func (p *Picker) captureForSurface(ls *LayerSurface) {
frame.SetReadyHandler(func(e wlr_screencopy.ZwlrScreencopyFrameV1ReadyEvent) { frame.SetReadyHandler(func(e wlr_screencopy.ZwlrScreencopyFrameV1ReadyEvent) {
ls.state.OnScreencopyReady() ls.state.OnScreencopyReady()
screenBuf := ls.state.ScreenBuffer()
if screenBuf != nil && ls.output.transform != TransformNormal {
invTransform := InverseTransform(ls.output.transform)
transformed, err := screenBuf.ApplyTransform(invTransform)
if err != nil {
log.Error("apply transform failed", "err", err)
} else if transformed != screenBuf {
ls.state.ReplaceScreenBuffer(transformed)
}
}
logicalW, _ := ls.state.LogicalSize() logicalW, _ := ls.state.LogicalSize()
screenBuf = ls.state.ScreenBuffer() screenBuf := ls.state.ScreenBuffer()
if logicalW > 0 && screenBuf != nil { if logicalW > 0 && screenBuf != nil {
ls.output.fractionalScale = float64(screenBuf.Width) / float64(logicalW) ls.output.fractionalScale = float64(screenBuf.Width) / float64(logicalW)
} }

View File

@@ -4,25 +4,10 @@ import "github.com/AvengeMedia/DankMaterialShell/core/internal/wayland/shm"
type ShmBuffer = shm.Buffer type ShmBuffer = shm.Buffer
const (
TransformNormal = shm.TransformNormal
Transform90 = shm.Transform90
Transform180 = shm.Transform180
Transform270 = shm.Transform270
TransformFlipped = shm.TransformFlipped
TransformFlipped90 = shm.TransformFlipped90
TransformFlipped180 = shm.TransformFlipped180
TransformFlipped270 = shm.TransformFlipped270
)
func CreateShmBuffer(width, height, stride int) (*ShmBuffer, error) { func CreateShmBuffer(width, height, stride int) (*ShmBuffer, error) {
return shm.CreateBuffer(width, height, stride) return shm.CreateBuffer(width, height, stride)
} }
func InverseTransform(transform int32) int32 {
return shm.InverseTransform(transform)
}
func GetPixelColor(buf *ShmBuffer, x, y int) Color { func GetPixelColor(buf *ShmBuffer, x, y int) Color {
return GetPixelColorWithFormat(buf, x, y, FormatARGB8888) return GetPixelColorWithFormat(buf, x, y, FormatARGB8888)
} }

View File

@@ -1,7 +1,6 @@
package colorpicker package colorpicker
import ( import (
"fmt"
"math" "math"
"strings" "strings"
"sync" "sync"
@@ -16,8 +15,6 @@ const (
FormatXRGB8888 = shm.FormatXRGB8888 FormatXRGB8888 = shm.FormatXRGB8888
FormatABGR8888 = shm.FormatABGR8888 FormatABGR8888 = shm.FormatABGR8888
FormatXBGR8888 = shm.FormatXBGR8888 FormatXBGR8888 = shm.FormatXBGR8888
FormatRGB888 = shm.FormatRGB888
FormatBGR888 = shm.FormatBGR888
) )
type SurfaceState struct { type SurfaceState struct {
@@ -82,11 +79,6 @@ func (s *SurfaceState) OnScreencopyBuffer(format PixelFormat, width, height, str
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
bpp := format.BytesPerPixel()
if stride < width*bpp {
return fmt.Errorf("invalid stride %d for width %d (bpp=%d)", stride, width, bpp)
}
if s.screenBuf != nil { if s.screenBuf != nil {
s.screenBuf.Close() s.screenBuf.Close()
s.screenBuf = nil s.screenBuf = nil
@@ -98,7 +90,6 @@ func (s *SurfaceState) OnScreencopyBuffer(format PixelFormat, width, height, str
} }
s.screenBuf = buf s.screenBuf = buf
s.screenBuf.Format = format
s.screenFormat = format s.screenFormat = format
return nil return nil
} }
@@ -115,20 +106,6 @@ func (s *SurfaceState) ScreenFormat() PixelFormat {
return s.screenFormat return s.screenFormat
} }
func (s *SurfaceState) ReplaceScreenBuffer(newBuf *ShmBuffer) {
s.mu.Lock()
defer s.mu.Unlock()
if s.screenBuf != nil {
s.screenBuf.Close()
}
s.screenBuf = newBuf
s.screenFormat = newBuf.Format
s.recomputeScale()
s.ensureRenderBuffers()
}
func (s *SurfaceState) OnScreencopyFlags(flags uint32) { func (s *SurfaceState) OnScreencopyFlags(flags uint32) {
s.mu.Lock() s.mu.Lock()
s.yInverted = (flags & 1) != 0 s.yInverted = (flags & 1) != 0
@@ -143,15 +120,6 @@ func (s *SurfaceState) OnScreencopyReady() {
return return
} }
if s.screenFormat.Is24Bit() {
converted, newFormat, err := s.screenBuf.ConvertTo32Bit(s.screenFormat)
if err == nil && converted != s.screenBuf {
s.screenBuf.Close()
s.screenBuf = converted
s.screenFormat = newFormat
}
}
s.recomputeScale() s.recomputeScale()
s.ensureRenderBuffers() s.ensureRenderBuffers()
s.readyForDisplay = true s.readyForDisplay = true
@@ -311,10 +279,10 @@ func (s *SurfaceState) Redraw() *ShmBuffer {
drawMagnifierWithInversion( drawMagnifierWithInversion(
dst.Data(), dst.Stride, dst.Width, dst.Height, dst.Data(), dst.Stride, dst.Width, dst.Height,
s.screenBuf.Data(), s.screenBuf.Stride, s.screenBuf.Width, s.screenBuf.Height, s.screenBuf.Data(), s.screenBuf.Stride, s.screenBuf.Width, s.screenBuf.Height,
px, py, picked, s.yInverted, s.screenFormat, px, py, picked, s.yInverted,
) )
drawColorPreview(dst.Data(), dst.Stride, dst.Width, dst.Height, px, py, picked, s.displayFormat, s.lowercase, s.screenFormat) drawColorPreview(dst.Data(), dst.Stride, dst.Width, dst.Height, px, py, picked, s.displayFormat, s.lowercase)
return dst return dst
} }
@@ -422,7 +390,6 @@ func drawMagnifierWithInversion(
cx, cy int, cx, cy int,
borderColor Color, borderColor Color,
yInverted bool, yInverted bool,
format PixelFormat,
) { ) {
if dstW <= 0 || dstH <= 0 || srcW <= 0 || srcH <= 0 { if dstW <= 0 || dstH <= 0 || srcW <= 0 || srcH <= 0 {
return return
@@ -440,14 +407,6 @@ func drawMagnifierWithInversion(
innerRadius := float64(outerRadius - borderThickness) innerRadius := float64(outerRadius - borderThickness)
outerRadiusF := float64(outerRadius) outerRadiusF := float64(outerRadius)
var rOff, bOff int
switch format {
case FormatABGR8888, FormatXBGR8888:
rOff, bOff = 0, 2
default:
rOff, bOff = 2, 0
}
for dy := -outerRadius - 2; dy <= outerRadius+2; dy++ { for dy := -outerRadius - 2; dy <= outerRadius+2; dy++ {
y := cy + dy y := cy + dy
if y < 0 || y >= dstH { if y < 0 || y >= dstH {
@@ -472,9 +431,9 @@ func drawMagnifierWithInversion(
} }
bgColor := Color{ bgColor := Color{
R: dst[dstOff+rOff], B: dst[dstOff+0],
G: dst[dstOff+1], G: dst[dstOff+1],
B: dst[dstOff+bOff], R: dst[dstOff+2],
A: dst[dstOff+3], A: dst[dstOff+3],
} }
@@ -503,7 +462,7 @@ func drawMagnifierWithInversion(
} }
srcOff := sy*srcStride + sx*4 srcOff := sy*srcStride + sx*4
if srcOff+4 <= len(src) { if srcOff+4 <= len(src) {
magColor := Color{R: src[srcOff+rOff], G: src[srcOff+1], B: src[srcOff+bOff], A: 255} magColor := Color{B: src[srcOff+0], G: src[srcOff+1], R: src[srcOff+2], A: 255}
finalColor = blendColors(magColor, borderColor, alpha) finalColor = blendColors(magColor, borderColor, alpha)
} else { } else {
finalColor = borderColor finalColor = borderColor
@@ -524,25 +483,24 @@ func drawMagnifierWithInversion(
} }
srcOff := sy*srcStride + sx*4 srcOff := sy*srcStride + sx*4
if srcOff+4 <= len(src) { if srcOff+4 <= len(src) {
finalColor = Color{R: src[srcOff+rOff], G: src[srcOff+1], B: src[srcOff+bOff], A: 255} finalColor = Color{B: src[srcOff+0], G: src[srcOff+1], R: src[srcOff+2], A: 255}
} else { } else {
continue continue
} }
} }
dst[dstOff+rOff] = finalColor.R dst[dstOff+0] = finalColor.B
dst[dstOff+1] = finalColor.G dst[dstOff+1] = finalColor.G
dst[dstOff+bOff] = finalColor.B dst[dstOff+2] = finalColor.R
dst[dstOff+3] = 255 dst[dstOff+3] = 255
} }
} }
drawMagnifierCrosshair(dst, dstStride, dstW, dstH, cx, cy, int(innerRadius), crossThickness, crossInnerRadius, format) drawMagnifierCrosshair(dst, dstStride, dstW, dstH, cx, cy, int(innerRadius), crossThickness, crossInnerRadius)
} }
func drawMagnifierCrosshair( func drawMagnifierCrosshair(
data []byte, stride, width, height, cx, cy, radius, thickness, innerRadius int, data []byte, stride, width, height, cx, cy, radius, thickness, innerRadius int,
format PixelFormat,
) { ) {
if width <= 0 || height <= 0 { if width <= 0 || height <= 0 {
return return
@@ -1040,7 +998,7 @@ var fontGlyphs = map[rune][fontH]uint8{
}, },
} }
func drawColorPreview(data []byte, stride, width, height int, cx, cy int, c Color, format OutputFormat, lowercase bool, pixelFormat PixelFormat) { func drawColorPreview(data []byte, stride, width, height int, cx, cy int, c Color, format OutputFormat, lowercase bool) {
text := formatColorForPreview(c, format, lowercase) text := formatColorForPreview(c, format, lowercase)
if len(text) == 0 { if len(text) == 0 {
return return
@@ -1075,8 +1033,9 @@ func drawColorPreview(data []byte, stride, width, height int, cx, cy int, c Colo
y = height - boxH y = height - boxH
} }
drawFilledRect(data, stride, width, height, x, y, boxW, boxH, c, pixelFormat) drawFilledRect(data, stride, width, height, x, y, boxW, boxH, c)
// Use contrasting text color based on luminance
lum := 0.299*float64(c.R) + 0.587*float64(c.G) + 0.114*float64(c.B) lum := 0.299*float64(c.R) + 0.587*float64(c.G) + 0.114*float64(c.B)
var fg Color var fg Color
if lum > 128 { if lum > 128 {
@@ -1084,7 +1043,7 @@ func drawColorPreview(data []byte, stride, width, height int, cx, cy int, c Colo
} else { } else {
fg = Color{R: 255, G: 255, B: 255, A: 255} fg = Color{R: 255, G: 255, B: 255, A: 255}
} }
drawText(data, stride, width, height, x+paddingX, y+paddingY, text, fg, pixelFormat) drawText(data, stride, width, height, x+paddingX, y+paddingY, text, fg)
} }
func formatColorForPreview(c Color, format OutputFormat, lowercase bool) string { func formatColorForPreview(c Color, format OutputFormat, lowercase bool) string {
@@ -1105,7 +1064,7 @@ func formatColorForPreview(c Color, format OutputFormat, lowercase bool) string
} }
} }
func drawFilledRect(data []byte, stride, width, height, x, y, w, h int, col Color, format PixelFormat) { func drawFilledRect(data []byte, stride, width, height, x, y, w, h int, col Color) {
if w <= 0 || h <= 0 { if w <= 0 || h <= 0 {
return return
} }
@@ -1114,14 +1073,6 @@ func drawFilledRect(data []byte, stride, width, height, x, y, w, h int, col Colo
x = clamp(x, 0, width) x = clamp(x, 0, width)
y = clamp(y, 0, height) y = clamp(y, 0, height)
var rOff, bOff int
switch format {
case FormatABGR8888, FormatXBGR8888:
rOff, bOff = 0, 2
default:
rOff, bOff = 2, 0
}
for yy := y; yy < yEnd; yy++ { for yy := y; yy < yEnd; yy++ {
rowOff := yy * stride rowOff := yy * stride
for xx := x; xx < xEnd; xx++ { for xx := x; xx < xEnd; xx++ {
@@ -1129,34 +1080,26 @@ func drawFilledRect(data []byte, stride, width, height, x, y, w, h int, col Colo
if off+4 > len(data) { if off+4 > len(data) {
continue continue
} }
data[off+rOff] = col.R data[off+0] = col.B
data[off+1] = col.G data[off+1] = col.G
data[off+bOff] = col.B data[off+2] = col.R
data[off+3] = 255 data[off+3] = 255
} }
} }
} }
func drawText(data []byte, stride, width, height, x, y int, text string, col Color, format PixelFormat) { func drawText(data []byte, stride, width, height, x, y int, text string, col Color) {
for i, r := range text { for i, r := range text {
drawGlyph(data, stride, width, height, x+i*(fontW+2), y, r, col, format) drawGlyph(data, stride, width, height, x+i*(fontW+2), y, r, col)
} }
} }
func drawGlyph(data []byte, stride, width, height, x, y int, r rune, col Color, format PixelFormat) { func drawGlyph(data []byte, stride, width, height, x, y int, r rune, col Color) {
g, ok := fontGlyphs[r] g, ok := fontGlyphs[r]
if !ok { if !ok {
return return
} }
var rOff, bOff int
switch format {
case FormatABGR8888, FormatXBGR8888:
rOff, bOff = 0, 2
default:
rOff, bOff = 2, 0
}
for row := 0; row < fontH; row++ { for row := 0; row < fontH; row++ {
yy := y + row yy := y + row
if yy < 0 || yy >= height { if yy < 0 || yy >= height {
@@ -1180,9 +1123,9 @@ func drawGlyph(data []byte, stride, width, height, x, y int, r rune, col Color,
continue continue
} }
data[off+rOff] = col.R data[off+0] = col.B
data[off+1] = col.G data[off+1] = col.G
data[off+bOff] = col.B data[off+2] = col.R
data[off+3] = 255 data[off+3] = 255
} }
} }

View File

@@ -46,20 +46,11 @@ func (cd *ConfigDeployer) DeployConfigurationsWithTerminal(ctx context.Context,
return cd.DeployConfigurationsSelective(ctx, wm, terminal, nil, nil) return cd.DeployConfigurationsSelective(ctx, wm, terminal, nil, nil)
} }
// DeployConfigurationsWithSystemd deploys configurations with systemd option
func (cd *ConfigDeployer) DeployConfigurationsWithSystemd(ctx context.Context, wm deps.WindowManager, terminal deps.Terminal, useSystemd bool) ([]DeploymentResult, error) {
return cd.deployConfigurationsInternal(ctx, wm, terminal, nil, nil, nil, useSystemd)
}
func (cd *ConfigDeployer) DeployConfigurationsSelective(ctx context.Context, wm deps.WindowManager, terminal deps.Terminal, installedDeps []deps.Dependency, replaceConfigs map[string]bool) ([]DeploymentResult, error) { func (cd *ConfigDeployer) DeployConfigurationsSelective(ctx context.Context, wm deps.WindowManager, terminal deps.Terminal, installedDeps []deps.Dependency, replaceConfigs map[string]bool) ([]DeploymentResult, error) {
return cd.DeployConfigurationsSelectiveWithReinstalls(ctx, wm, terminal, installedDeps, replaceConfigs, nil) return cd.DeployConfigurationsSelectiveWithReinstalls(ctx, wm, terminal, installedDeps, replaceConfigs, nil)
} }
func (cd *ConfigDeployer) DeployConfigurationsSelectiveWithReinstalls(ctx context.Context, wm deps.WindowManager, terminal deps.Terminal, installedDeps []deps.Dependency, replaceConfigs map[string]bool, reinstallItems map[string]bool) ([]DeploymentResult, error) { func (cd *ConfigDeployer) DeployConfigurationsSelectiveWithReinstalls(ctx context.Context, wm deps.WindowManager, terminal deps.Terminal, installedDeps []deps.Dependency, replaceConfigs map[string]bool, reinstallItems map[string]bool) ([]DeploymentResult, error) {
return cd.deployConfigurationsInternal(ctx, wm, terminal, installedDeps, replaceConfigs, reinstallItems, true)
}
func (cd *ConfigDeployer) deployConfigurationsInternal(ctx context.Context, wm deps.WindowManager, terminal deps.Terminal, installedDeps []deps.Dependency, replaceConfigs map[string]bool, reinstallItems map[string]bool, useSystemd bool) ([]DeploymentResult, error) {
var results []DeploymentResult var results []DeploymentResult
shouldReplaceConfig := func(configType string) bool { shouldReplaceConfig := func(configType string) bool {
@@ -73,7 +64,7 @@ func (cd *ConfigDeployer) deployConfigurationsInternal(ctx context.Context, wm d
switch wm { switch wm {
case deps.WindowManagerNiri: case deps.WindowManagerNiri:
if shouldReplaceConfig("Niri") { if shouldReplaceConfig("Niri") {
result, err := cd.deployNiriConfig(terminal, useSystemd) result, err := cd.deployNiriConfig(terminal)
results = append(results, result) results = append(results, result)
if err != nil { if err != nil {
return results, fmt.Errorf("failed to deploy Niri config: %w", err) return results, fmt.Errorf("failed to deploy Niri config: %w", err)
@@ -81,7 +72,7 @@ func (cd *ConfigDeployer) deployConfigurationsInternal(ctx context.Context, wm d
} }
case deps.WindowManagerHyprland: case deps.WindowManagerHyprland:
if shouldReplaceConfig("Hyprland") { if shouldReplaceConfig("Hyprland") {
result, err := cd.deployHyprlandConfig(terminal, useSystemd) result, err := cd.deployHyprlandConfig(terminal)
results = append(results, result) results = append(results, result)
if err != nil { if err != nil {
return results, fmt.Errorf("failed to deploy Hyprland config: %w", err) return results, fmt.Errorf("failed to deploy Hyprland config: %w", err)
@@ -119,7 +110,7 @@ func (cd *ConfigDeployer) deployConfigurationsInternal(ctx context.Context, wm d
return results, nil return results, nil
} }
func (cd *ConfigDeployer) deployNiriConfig(terminal deps.Terminal, useSystemd bool) (DeploymentResult, error) { func (cd *ConfigDeployer) deployNiriConfig(terminal deps.Terminal) (DeploymentResult, error) {
result := DeploymentResult{ result := DeploymentResult{
ConfigType: "Niri", ConfigType: "Niri",
Path: filepath.Join(os.Getenv("HOME"), ".config", "niri", "config.kdl"), Path: filepath.Join(os.Getenv("HOME"), ".config", "niri", "config.kdl"),
@@ -157,6 +148,12 @@ func (cd *ConfigDeployer) deployNiriConfig(terminal deps.Terminal, useSystemd bo
cd.log(fmt.Sprintf("Backed up existing config to %s", result.BackupPath)) cd.log(fmt.Sprintf("Backed up existing config to %s", result.BackupPath))
} }
polkitPath, err := cd.detectPolkitAgent()
if err != nil {
cd.log(fmt.Sprintf("Warning: Could not detect polkit agent: %v", err))
polkitPath = "/usr/lib/mate-polkit/polkit-mate-authentication-agent-1"
}
var terminalCommand string var terminalCommand string
switch terminal { switch terminal {
case deps.TerminalGhostty: case deps.TerminalGhostty:
@@ -169,11 +166,8 @@ func (cd *ConfigDeployer) deployNiriConfig(terminal deps.Terminal, useSystemd bo
terminalCommand = "ghostty" terminalCommand = "ghostty"
} }
newConfig := strings.ReplaceAll(NiriConfig, "{{TERMINAL_COMMAND}}", terminalCommand) newConfig := strings.ReplaceAll(NiriConfig, "{{POLKIT_AGENT_PATH}}", polkitPath)
newConfig = strings.ReplaceAll(newConfig, "{{TERMINAL_COMMAND}}", terminalCommand)
if !useSystemd {
newConfig = cd.transformNiriConfigForNonSystemd(newConfig, terminalCommand)
}
if existingConfig != "" { if existingConfig != "" {
mergedConfig, err := cd.mergeNiriOutputSections(newConfig, existingConfig) mergedConfig, err := cd.mergeNiriOutputSections(newConfig, existingConfig)
@@ -410,6 +404,41 @@ func (cd *ConfigDeployer) deployAlacrittyConfig() ([]DeploymentResult, error) {
return results, nil return results, nil
} }
// detectPolkitAgent tries to find the polkit authentication agent on the system
// Prioritizes mate-polkit paths since that's what we install
func (cd *ConfigDeployer) detectPolkitAgent() (string, error) {
// Prioritize mate-polkit paths first
matePaths := []string{
"/usr/libexec/polkit-mate-authentication-agent-1", // Fedora path
"/usr/lib/mate-polkit/polkit-mate-authentication-agent-1",
"/usr/libexec/mate-polkit/polkit-mate-authentication-agent-1",
"/usr/lib/polkit-mate/polkit-mate-authentication-agent-1",
"/usr/lib/x86_64-linux-gnu/mate-polkit/polkit-mate-authentication-agent-1",
}
for _, path := range matePaths {
if _, err := os.Stat(path); err == nil {
cd.log(fmt.Sprintf("Found mate-polkit agent at: %s", path))
return path, nil
}
}
// Fallback to other polkit agents if mate-polkit is not found
fallbackPaths := []string{
"/usr/lib/polkit-gnome/polkit-gnome-authentication-agent-1",
"/usr/libexec/polkit-gnome-authentication-agent-1",
}
for _, path := range fallbackPaths {
if _, err := os.Stat(path); err == nil {
cd.log(fmt.Sprintf("Found fallback polkit agent at: %s", path))
return path, nil
}
}
return "", fmt.Errorf("no polkit agent found in common locations")
}
// mergeNiriOutputSections extracts output sections from existing config and merges them into the new config // mergeNiriOutputSections extracts output sections from existing config and merges them into the new config
func (cd *ConfigDeployer) mergeNiriOutputSections(newConfig, existingConfig string) (string, error) { func (cd *ConfigDeployer) mergeNiriOutputSections(newConfig, existingConfig string) (string, error) {
// Regular expression to match output sections (including commented ones) // Regular expression to match output sections (including commented ones)
@@ -453,7 +482,7 @@ func (cd *ConfigDeployer) mergeNiriOutputSections(newConfig, existingConfig stri
} }
// deployHyprlandConfig handles Hyprland configuration deployment with backup and merging // deployHyprlandConfig handles Hyprland configuration deployment with backup and merging
func (cd *ConfigDeployer) deployHyprlandConfig(terminal deps.Terminal, useSystemd bool) (DeploymentResult, error) { func (cd *ConfigDeployer) deployHyprlandConfig(terminal deps.Terminal) (DeploymentResult, error) {
result := DeploymentResult{ result := DeploymentResult{
ConfigType: "Hyprland", ConfigType: "Hyprland",
Path: filepath.Join(os.Getenv("HOME"), ".config", "hypr", "hyprland.conf"), Path: filepath.Join(os.Getenv("HOME"), ".config", "hypr", "hyprland.conf"),
@@ -485,6 +514,14 @@ func (cd *ConfigDeployer) deployHyprlandConfig(terminal deps.Terminal, useSystem
cd.log(fmt.Sprintf("Backed up existing config to %s", result.BackupPath)) cd.log(fmt.Sprintf("Backed up existing config to %s", result.BackupPath))
} }
// Detect polkit agent path
polkitPath, err := cd.detectPolkitAgent()
if err != nil {
cd.log(fmt.Sprintf("Warning: Could not detect polkit agent: %v", err))
polkitPath = "/usr/lib/mate-polkit/polkit-mate-authentication-agent-1" // fallback
}
// Determine terminal command based on choice
var terminalCommand string var terminalCommand string
switch terminal { switch terminal {
case deps.TerminalGhostty: case deps.TerminalGhostty:
@@ -494,15 +531,13 @@ func (cd *ConfigDeployer) deployHyprlandConfig(terminal deps.Terminal, useSystem
case deps.TerminalAlacritty: case deps.TerminalAlacritty:
terminalCommand = "alacritty" terminalCommand = "alacritty"
default: default:
terminalCommand = "ghostty" terminalCommand = "ghostty" // fallback to ghostty
} }
newConfig := strings.ReplaceAll(HyprlandConfig, "{{TERMINAL_COMMAND}}", terminalCommand) newConfig := strings.ReplaceAll(HyprlandConfig, "{{POLKIT_AGENT_PATH}}", polkitPath)
newConfig = strings.ReplaceAll(newConfig, "{{TERMINAL_COMMAND}}", terminalCommand)
if !useSystemd {
newConfig = cd.transformHyprlandConfigForNonSystemd(newConfig, terminalCommand)
}
// If there was an existing config, merge the monitor sections
if existingConfig != "" { if existingConfig != "" {
mergedConfig, err := cd.mergeHyprlandMonitorSections(newConfig, existingConfig) mergedConfig, err := cd.mergeHyprlandMonitorSections(newConfig, existingConfig)
if err != nil { if err != nil {
@@ -525,16 +560,24 @@ func (cd *ConfigDeployer) deployHyprlandConfig(terminal deps.Terminal, useSystem
// mergeHyprlandMonitorSections extracts monitor sections from existing config and merges them into the new config // mergeHyprlandMonitorSections extracts monitor sections from existing config and merges them into the new config
func (cd *ConfigDeployer) mergeHyprlandMonitorSections(newConfig, existingConfig string) (string, error) { func (cd *ConfigDeployer) mergeHyprlandMonitorSections(newConfig, existingConfig string) (string, error) {
// Regular expression to match monitor lines (including commented ones)
// Matches: monitor = NAME, RESOLUTION, POSITION, SCALE, etc.
// Also matches commented versions: # monitor = ...
monitorRegex := regexp.MustCompile(`(?m)^#?\s*monitor\s*=.*$`) monitorRegex := regexp.MustCompile(`(?m)^#?\s*monitor\s*=.*$`)
// Find all monitor lines in the existing config
existingMonitors := monitorRegex.FindAllString(existingConfig, -1) existingMonitors := monitorRegex.FindAllString(existingConfig, -1)
if len(existingMonitors) == 0 { if len(existingMonitors) == 0 {
// No monitor sections to merge
return newConfig, nil return newConfig, nil
} }
// Remove the example monitor line from the new config
exampleMonitorRegex := regexp.MustCompile(`(?m)^# monitor = eDP-2.*$`) exampleMonitorRegex := regexp.MustCompile(`(?m)^# monitor = eDP-2.*$`)
mergedConfig := exampleMonitorRegex.ReplaceAllString(newConfig, "") mergedConfig := exampleMonitorRegex.ReplaceAllString(newConfig, "")
// Find where to insert the monitor sections (after the MONITOR CONFIG header)
monitorHeaderRegex := regexp.MustCompile(`(?m)^# MONITOR CONFIG\n# ==================$`) monitorHeaderRegex := regexp.MustCompile(`(?m)^# MONITOR CONFIG\n# ==================$`)
headerMatch := monitorHeaderRegex.FindStringIndex(mergedConfig) headerMatch := monitorHeaderRegex.FindStringIndex(mergedConfig)
@@ -542,7 +585,8 @@ func (cd *ConfigDeployer) mergeHyprlandMonitorSections(newConfig, existingConfig
return "", fmt.Errorf("could not find MONITOR CONFIG section") return "", fmt.Errorf("could not find MONITOR CONFIG section")
} }
insertPos := headerMatch[1] + 1 // Insert after the header
insertPos := headerMatch[1] + 1 // +1 for the newline
var builder strings.Builder var builder strings.Builder
builder.WriteString(mergedConfig[:insertPos]) builder.WriteString(mergedConfig[:insertPos])
@@ -557,69 +601,3 @@ func (cd *ConfigDeployer) mergeHyprlandMonitorSections(newConfig, existingConfig
return builder.String(), nil return builder.String(), nil
} }
func (cd *ConfigDeployer) transformHyprlandConfigForNonSystemd(config, terminalCommand string) string {
lines := strings.Split(config, "\n")
var result []string
startupSectionFound := false
for _, line := range lines {
trimmed := strings.TrimSpace(line)
if strings.HasPrefix(trimmed, "exec-once = dbus-update-activation-environment") {
continue
}
if strings.HasPrefix(trimmed, "exec-once = systemctl --user start") {
startupSectionFound = true
result = append(result, "exec-once = dms run")
result = append(result, "env = QT_QPA_PLATFORM,wayland")
result = append(result, "env = ELECTRON_OZONE_PLATFORM_HINT,auto")
result = append(result, "env = QT_QPA_PLATFORMTHEME,gtk3")
result = append(result, "env = QT_QPA_PLATFORMTHEME_QT6,gtk3")
result = append(result, fmt.Sprintf("env = TERMINAL,%s", terminalCommand))
continue
}
result = append(result, line)
}
if !startupSectionFound {
for i, line := range result {
if strings.Contains(line, "STARTUP APPS") {
insertLines := []string{
"exec-once = dms run",
"env = QT_QPA_PLATFORM,wayland",
"env = ELECTRON_OZONE_PLATFORM_HINT,auto",
"env = QT_QPA_PLATFORMTHEME,gtk3",
"env = QT_QPA_PLATFORMTHEME_QT6,gtk3",
fmt.Sprintf("env = TERMINAL,%s", terminalCommand),
}
result = append(result[:i+2], append(insertLines, result[i+2:]...)...)
break
}
}
}
return strings.Join(result, "\n")
}
func (cd *ConfigDeployer) transformNiriConfigForNonSystemd(config, terminalCommand string) string {
envVars := fmt.Sprintf(`environment {
XDG_CURRENT_DESKTOP "niri"
QT_QPA_PLATFORM "wayland"
ELECTRON_OZONE_PLATFORM_HINT "auto"
QT_QPA_PLATFORMTHEME "gtk3"
QT_QPA_PLATFORMTHEME_QT6 "gtk3"
TERMINAL "%s"
}`, terminalCommand)
config = regexp.MustCompile(`environment \{[^}]*\}`).ReplaceAllString(config, envVars)
spawnDms := `spawn-at-startup "dms" "run"`
if !strings.Contains(config, spawnDms) {
config = strings.Replace(config,
`spawn-at-startup "bash" "-c" "wl-paste --watch cliphist store &"`,
`spawn-at-startup "bash" "-c" "wl-paste --watch cliphist store &"`+"\n"+spawnDms,
1)
}
return config
}

View File

@@ -3,6 +3,7 @@ package config
import ( import (
"os" "os"
"path/filepath" "path/filepath"
"strings"
"testing" "testing"
"github.com/AvengeMedia/DankMaterialShell/core/internal/deps" "github.com/AvengeMedia/DankMaterialShell/core/internal/deps"
@@ -10,6 +11,23 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestDetectPolkitAgent(t *testing.T) {
cd := &ConfigDeployer{}
// This test depends on the system having a polkit agent installed
// We'll just test that the function doesn't crash and returns some path or error
path, err := cd.detectPolkitAgent()
if err != nil {
// If no polkit agent is found, that's okay for testing
assert.Contains(t, err.Error(), "no polkit agent found")
} else {
// If found, it should be a valid path
assert.NotEmpty(t, path)
assert.True(t, strings.Contains(path, "polkit"))
}
}
func TestMergeNiriOutputSections(t *testing.T) { func TestMergeNiriOutputSections(t *testing.T) {
cd := &ConfigDeployer{} cd := &ConfigDeployer{}
@@ -254,6 +272,17 @@ func getGhosttyPath() string {
return filepath.Join(os.Getenv("HOME"), ".config", "ghostty", "config") return filepath.Join(os.Getenv("HOME"), ".config", "ghostty", "config")
} }
func TestPolkitPathInjection(t *testing.T) {
testConfig := `spawn-at-startup "{{POLKIT_AGENT_PATH}}"
other content`
result := strings.Replace(testConfig, "{{POLKIT_AGENT_PATH}}", "/test/polkit/path", 1)
assert.Contains(t, result, `spawn-at-startup "/test/polkit/path"`)
assert.NotContains(t, result, "{{POLKIT_AGENT_PATH}}")
}
func TestMergeHyprlandMonitorSections(t *testing.T) { func TestMergeHyprlandMonitorSections(t *testing.T) {
cd := &ConfigDeployer{} cd := &ConfigDeployer{}
@@ -395,7 +424,7 @@ func TestHyprlandConfigDeployment(t *testing.T) {
cd := NewConfigDeployer(logChan) cd := NewConfigDeployer(logChan)
t.Run("deploy hyprland config to empty directory", func(t *testing.T) { t.Run("deploy hyprland config to empty directory", func(t *testing.T) {
result, err := cd.deployHyprlandConfig(deps.TerminalGhostty, true) result, err := cd.deployHyprlandConfig(deps.TerminalGhostty)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "Hyprland", result.ConfigType) assert.Equal(t, "Hyprland", result.ConfigType)
@@ -406,7 +435,7 @@ func TestHyprlandConfigDeployment(t *testing.T) {
content, err := os.ReadFile(result.Path) content, err := os.ReadFile(result.Path)
require.NoError(t, err) require.NoError(t, err)
assert.Contains(t, string(content), "# MONITOR CONFIG") assert.Contains(t, string(content), "# MONITOR CONFIG")
assert.Contains(t, string(content), "bind = $mod, T, exec, ghostty") assert.Contains(t, string(content), "bind = $mod, T, exec, $TERMINAL")
assert.Contains(t, string(content), "exec-once = ") assert.Contains(t, string(content), "exec-once = ")
}) })
@@ -425,7 +454,7 @@ general {
err = os.WriteFile(hyprPath, []byte(existingContent), 0644) err = os.WriteFile(hyprPath, []byte(existingContent), 0644)
require.NoError(t, err) require.NoError(t, err)
result, err := cd.deployHyprlandConfig(deps.TerminalKitty, true) result, err := cd.deployHyprlandConfig(deps.TerminalKitty)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "Hyprland", result.ConfigType) assert.Equal(t, "Hyprland", result.ConfigType)
@@ -442,7 +471,7 @@ general {
require.NoError(t, err) require.NoError(t, err)
assert.Contains(t, string(newContent), "monitor = DP-1, 1920x1080@144") assert.Contains(t, string(newContent), "monitor = DP-1, 1920x1080@144")
assert.Contains(t, string(newContent), "monitor = HDMI-A-1, 3840x2160@60") assert.Contains(t, string(newContent), "monitor = HDMI-A-1, 3840x2160@60")
assert.Contains(t, string(newContent), "bind = $mod, T, exec, kitty") assert.Contains(t, string(newContent), "bind = $mod, T, exec, $TERMINAL")
assert.NotContains(t, string(newContent), "monitor = eDP-2") assert.NotContains(t, string(newContent), "monitor = eDP-2")
}) })
} }
@@ -450,6 +479,7 @@ general {
func TestNiriConfigStructure(t *testing.T) { func TestNiriConfigStructure(t *testing.T) {
assert.Contains(t, NiriConfig, "input {") assert.Contains(t, NiriConfig, "input {")
assert.Contains(t, NiriConfig, "layout {") assert.Contains(t, NiriConfig, "layout {")
assert.Contains(t, NiriConfig, "{{POLKIT_AGENT_PATH}}")
assert.Contains(t, NiriBindsConfig, "binds {") assert.Contains(t, NiriBindsConfig, "binds {")
assert.Contains(t, NiriBindsConfig, `spawn "{{TERMINAL_COMMAND}}"`) assert.Contains(t, NiriBindsConfig, `spawn "{{TERMINAL_COMMAND}}"`)
@@ -460,9 +490,11 @@ func TestHyprlandConfigStructure(t *testing.T) {
assert.Contains(t, HyprlandConfig, "# STARTUP APPS") assert.Contains(t, HyprlandConfig, "# STARTUP APPS")
assert.Contains(t, HyprlandConfig, "# INPUT CONFIG") assert.Contains(t, HyprlandConfig, "# INPUT CONFIG")
assert.Contains(t, HyprlandConfig, "# KEYBINDINGS") assert.Contains(t, HyprlandConfig, "# KEYBINDINGS")
assert.Contains(t, HyprlandConfig, "bind = $mod, T, exec, {{TERMINAL_COMMAND}}") assert.Contains(t, HyprlandConfig, "{{POLKIT_AGENT_PATH}}")
assert.Contains(t, HyprlandConfig, "bind = $mod, T, exec, $TERMINAL")
assert.Contains(t, HyprlandConfig, "bind = $mod, space, exec, dms ipc call spotlight toggle") assert.Contains(t, HyprlandConfig, "bind = $mod, space, exec, dms ipc call spotlight toggle")
assert.Contains(t, HyprlandConfig, "windowrulev2 = noborder, class:^(com\\.mitchellh\\.ghostty)$") assert.Contains(t, HyprlandConfig, "windowrule {")
assert.Contains(t, HyprlandConfig, "match:class = ^(com\\.mitchellh\\.ghostty)$")
} }
func TestGhosttyConfigStructure(t *testing.T) { func TestGhosttyConfigStructure(t *testing.T) {

View File

@@ -5,14 +5,19 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"strings" "strings"
"github.com/AvengeMedia/DankMaterialShell/core/internal/utils"
) )
// LocateDMSConfig searches for DMS installation following XDG Base Directory specification
func LocateDMSConfig() (string, error) { func LocateDMSConfig() (string, error) {
var primaryPaths []string var primaryPaths []string
configHome := utils.XDGConfigHome() configHome := os.Getenv("XDG_CONFIG_HOME")
if configHome == "" {
if homeDir, err := os.UserHomeDir(); err == nil {
configHome = filepath.Join(homeDir, ".config")
}
}
if configHome != "" { if configHome != "" {
primaryPaths = append(primaryPaths, filepath.Join(configHome, "quickshell", "dms")) primaryPaths = append(primaryPaths, filepath.Join(configHome, "quickshell", "dms"))
} }

View File

@@ -10,9 +10,8 @@ monitor = , preferred,auto,auto
# ================== # ==================
# STARTUP APPS # STARTUP APPS
# ================== # ==================
exec-once = dbus-update-activation-environment --systemd --all
exec-once = systemctl --user start hyprland-session.target
exec-once = bash -c "wl-paste --watch cliphist store &" exec-once = bash -c "wl-paste --watch cliphist store &"
exec-once = {{POLKIT_AGENT_PATH}}
# ================== # ==================
# INPUT CONFIG # INPUT CONFIG
@@ -91,36 +90,132 @@ misc {
# ================== # ==================
# WINDOW RULES # WINDOW RULES
# ================== # ==================
windowrulev2 = tile, class:^(org\.wezfurlong\.wezterm)$ windowrule {
name = windowrule-1
tile = on
match:class = ^(org\.wezfurlong\.wezterm)$
border_size = 0
}
windowrulev2 = rounding 12, class:^(org\.gnome\.)
windowrulev2 = noborder, class:^(org\.gnome\.)
windowrulev2 = tile, class:^(gnome-control-center)$ windowrule {
windowrulev2 = tile, class:^(pavucontrol)$ name = windowrule-2
windowrulev2 = tile, class:^(nm-connection-editor)$ rounding = 12
match:class = ^(org\.gnome\.)
border_size = 0
}
windowrulev2 = float, class:^(gnome-calculator)$
windowrulev2 = float, class:^(galculator)$
windowrulev2 = float, class:^(blueman-manager)$
windowrulev2 = float, class:^(org\.gnome\.Nautilus)$
windowrulev2 = float, class:^(steam)$
windowrulev2 = float, class:^(xdg-desktop-portal)$
windowrulev2 = noborder, class:^(org\.wezfurlong\.wezterm)$
windowrulev2 = noborder, class:^(Alacritty)$
windowrulev2 = noborder, class:^(zen)$
windowrulev2 = noborder, class:^(com\.mitchellh\.ghostty)$
windowrulev2 = noborder, class:^(kitty)$
windowrulev2 = float, class:^(firefox)$, title:^(Picture-in-Picture)$ windowrule {
windowrulev2 = float, class:^(zoom)$ name = windowrule-3
tile = on
match:class = ^(gnome-control-center)$
}
# DMS windows floating by default windowrule {
windowrulev2 = float, class:^(org.quickshell)$ name = windowrule-4
windowrulev2 = opacity 0.9 0.9, floating:0, focus:0 tile = on
match:class = ^(pavucontrol)$
}
layerrule = noanim, ^(quickshell)$ windowrule {
name = windowrule-5
tile = on
match:class = ^(nm-connection-editor)$
}
windowrule {
name = windowrule-6
float = on
match:class = ^(gnome-calculator)$
}
windowrule {
name = windowrule-7
float = on
match:class = ^(galculator)$
}
windowrule {
name = windowrule-8
float = on
match:class = ^(blueman-manager)$
}
windowrule {
name = windowrule-9
float = on
match:class = ^(org\.gnome\.Nautilus)$
}
windowrule {
name = windowrule-10
float = on
match:class = ^(steam)$
}
windowrule {
name = windowrule-11
float = on
match:class = ^(xdg-desktop-portal)$
}
windowrule {
name = windowrule-12
border_size = 0
match:class = ^(Alacritty)$
}
windowrule {
name = windowrule-13
border_size = 0
match:class = ^(zen)$
}
windowrule {
name = windowrule-14
border_size = 0
match:class = ^(com\.mitchellh\.ghostty)$
}
windowrule {
name = windowrule-15
border_size = 0
match:class = ^(kitty)$
}
windowrule {
name = windowrule-16
float = on
match:class = ^(firefox)$
match:title = ^(Picture-in-Picture)$
}
windowrule {
name = windowrule-17
float = on
match:class = ^(zoom)$
}
windowrule {
name = windowrule-18
opacity = 0.9 0.9
match:float = 0
match:focus = 0
}
layerrule {
name = layerrule-1
no_anim = on
match:namespace = ^(quickshell)$
}
# ================== # ==================
# KEYBINDINGS # KEYBINDINGS
@@ -128,7 +223,7 @@ layerrule = noanim, ^(quickshell)$
$mod = SUPER $mod = SUPER
# === Application Launchers === # === Application Launchers ===
bind = $mod, T, exec, {{TERMINAL_COMMAND}} bind = $mod, T, exec, $TERMINAL
bind = $mod, space, exec, dms ipc call spotlight toggle bind = $mod, space, exec, dms ipc call spotlight toggle
bind = $mod, V, exec, dms ipc call clipboard toggle bind = $mod, V, exec, dms ipc call clipboard toggle
bind = $mod, M, exec, dms ipc call processlist focusOrToggle bind = $mod, M, exec, dms ipc call processlist focusOrToggle

View File

@@ -44,6 +44,7 @@ input {
// https://github.com/YaLTeR/niri/wiki/Configuration:-Layout // https://github.com/YaLTeR/niri/wiki/Configuration:-Layout
layout { layout {
// Set gaps around windows in logical pixels. // Set gaps around windows in logical pixels.
gaps 5
background-color "transparent" background-color "transparent"
// When to center a column when changing focus, options are: // When to center a column when changing focus, options are:
// - "never", default behavior, focusing an off-screen column will keep at the left // - "never", default behavior, focusing an off-screen column will keep at the left
@@ -86,6 +87,11 @@ layout {
inactive-color "#d0d0d0" // Light gray inactive-color "#d0d0d0" // Light gray
urgent-color "#cc4444" // Softer red urgent-color "#cc4444" // Softer red
} }
focus-ring {
width 2
active-color "#808080" // Medium gray
inactive-color "#505050" // Dark gray
}
shadow { shadow {
softness 30 softness 30
spread 5 spread 5
@@ -110,6 +116,7 @@ overview {
// See the binds section below for more spawn examples. // See the binds section below for more spawn examples.
// This line starts waybar, a commonly used bar for Wayland compositors. // This line starts waybar, a commonly used bar for Wayland compositors.
spawn-at-startup "bash" "-c" "wl-paste --watch cliphist store &" spawn-at-startup "bash" "-c" "wl-paste --watch cliphist store &"
spawn-at-startup "{{POLKIT_AGENT_PATH}}"
environment { environment {
XDG_CURRENT_DESKTOP "niri" XDG_CURRENT_DESKTOP "niri"
} }

View File

@@ -113,14 +113,13 @@ func RGBToHSV(rgb RGB) HSV {
delta := max - min delta := max - min
var h float64 var h float64
switch { if delta == 0 {
case delta == 0:
h = 0 h = 0
case max == rgb.R: } else if max == rgb.R {
h = math.Mod((rgb.G-rgb.B)/delta, 6.0) / 6.0 h = math.Mod((rgb.G-rgb.B)/delta, 6.0) / 6.0
case max == rgb.G: } else if max == rgb.G {
h = ((rgb.B-rgb.R)/delta + 2.0) / 6.0 h = ((rgb.B-rgb.R)/delta + 2.0) / 6.0
default: } else {
h = ((rgb.R-rgb.G)/delta + 4.0) / 6.0 h = ((rgb.R-rgb.G)/delta + 4.0) / 6.0
} }

View File

@@ -112,11 +112,31 @@ func (a *ArchDistribution) DetectDependenciesWithTerminal(ctx context.Context, w
} }
func (a *ArchDistribution) detectXDGPortal() deps.Dependency { func (a *ArchDistribution) detectXDGPortal() deps.Dependency {
return a.detectPackage("xdg-desktop-portal-gtk", "Desktop integration portal for GTK", a.packageInstalled("xdg-desktop-portal-gtk")) status := deps.StatusMissing
if a.packageInstalled("xdg-desktop-portal-gtk") {
status = deps.StatusInstalled
}
return deps.Dependency{
Name: "xdg-desktop-portal-gtk",
Status: status,
Description: "Desktop integration portal for GTK",
Required: true,
}
} }
func (a *ArchDistribution) detectAccountsService() deps.Dependency { func (a *ArchDistribution) detectAccountsService() deps.Dependency {
return a.detectPackage("accountsservice", "D-Bus interface for user account query and manipulation", a.packageInstalled("accountsservice")) status := deps.StatusMissing
if a.packageInstalled("accountsservice") {
status = deps.StatusInstalled
}
return deps.Dependency{
Name: "accountsservice",
Status: status,
Description: "D-Bus interface for user account query and manipulation",
Required: true,
}
} }
func (a *ArchDistribution) packageInstalled(pkg string) bool { func (a *ArchDistribution) packageInstalled(pkg string) bool {
@@ -162,11 +182,13 @@ func (a *ArchDistribution) getQuickshellMapping(variant deps.PackageVariant) Pac
if forceQuickshellGit || variant == deps.VariantGit { if forceQuickshellGit || variant == deps.VariantGit {
return PackageMapping{Name: "quickshell-git", Repository: RepoTypeAUR} return PackageMapping{Name: "quickshell-git", Repository: RepoTypeAUR}
} }
// ! TODO - for now we're only forcing quickshell-git on ARCH, as other distros use DL repos which pin a newer quickshell return PackageMapping{Name: "quickshell", Repository: RepoTypeSystem}
return PackageMapping{Name: "quickshell-git", Repository: RepoTypeAUR}
} }
func (a *ArchDistribution) getHyprlandMapping(_ deps.PackageVariant) PackageMapping { func (a *ArchDistribution) getHyprlandMapping(variant deps.PackageVariant) PackageMapping {
if variant == deps.VariantGit {
return PackageMapping{Name: "hyprland-git", Repository: RepoTypeAUR}
}
return PackageMapping{Name: "hyprland", Repository: RepoTypeSystem} return PackageMapping{Name: "hyprland", Repository: RepoTypeSystem}
} }
@@ -340,11 +362,7 @@ func (a *ArchDistribution) InstallPackages(ctx context.Context, dependencies []d
a.log(fmt.Sprintf("Warning: failed to write environment config: %v", err)) a.log(fmt.Sprintf("Warning: failed to write environment config: %v", err))
} }
if err := a.WriteWindowManagerConfig(wm); err != nil { if err := a.EnableDMSService(ctx); err != nil {
a.log(fmt.Sprintf("Warning: failed to write window manager config: %v", err))
}
if err := a.EnableDMSService(ctx, wm); err != nil {
a.log(fmt.Sprintf("Warning: failed to enable dms service: %v", err)) a.log(fmt.Sprintf("Warning: failed to enable dms service: %v", err))
} }

View File

@@ -76,42 +76,47 @@ func ExecSudoCommand(ctx context.Context, sudoPassword string, command string) *
return exec.CommandContext(ctx, "bash", "-c", cmdStr) return exec.CommandContext(ctx, "bash", "-c", cmdStr)
} }
func (b *BaseDistribution) detectCommand(name, description string) deps.Dependency { // Common dependency detection methods
status := deps.StatusMissing
if b.commandExists(name) {
status = deps.StatusInstalled
}
return deps.Dependency{
Name: name,
Status: status,
Description: description,
Required: true,
}
}
func (b *BaseDistribution) detectPackage(name, description string, installed bool) deps.Dependency {
status := deps.StatusMissing
if installed {
status = deps.StatusInstalled
}
return deps.Dependency{
Name: name,
Status: status,
Description: description,
Required: true,
}
}
func (b *BaseDistribution) detectGit() deps.Dependency { func (b *BaseDistribution) detectGit() deps.Dependency {
return b.detectCommand("git", "Version control system") status := deps.StatusMissing
if b.commandExists("git") {
status = deps.StatusInstalled
}
return deps.Dependency{
Name: "git",
Status: status,
Description: "Version control system",
Required: true,
}
} }
func (b *BaseDistribution) detectMatugen() deps.Dependency { func (b *BaseDistribution) detectMatugen() deps.Dependency {
return b.detectCommand("matugen", "Material Design color generation tool") status := deps.StatusMissing
if b.commandExists("matugen") {
status = deps.StatusInstalled
}
return deps.Dependency{
Name: "matugen",
Status: status,
Description: "Material Design color generation tool",
Required: true,
}
} }
func (b *BaseDistribution) detectDgop() deps.Dependency { func (b *BaseDistribution) detectDgop() deps.Dependency {
return b.detectCommand("dgop", "Desktop portal management tool") status := deps.StatusMissing
if b.commandExists("dgop") {
status = deps.StatusInstalled
}
return deps.Dependency{
Name: "dgop",
Status: status,
Description: "Desktop portal management tool",
Required: true,
}
} }
func (b *BaseDistribution) detectDMS() deps.Dependency { func (b *BaseDistribution) detectDMS() deps.Dependency {
@@ -597,59 +602,12 @@ TERMINAL=%s
return nil return nil
} }
func (b *BaseDistribution) EnableDMSService(ctx context.Context, wm deps.WindowManager) error { func (b *BaseDistribution) EnableDMSService(ctx context.Context) error {
cmd := exec.CommandContext(ctx, "systemctl", "--user", "enable", "--now", "dms") cmd := exec.CommandContext(ctx, "systemctl", "--user", "enable", "--now", "dms")
if err := cmd.Run(); err != nil { if err := cmd.Run(); err != nil {
return fmt.Errorf("failed to enable dms service: %w", err) return fmt.Errorf("failed to enable dms service: %w", err)
} }
b.log("Enabled dms systemd user service") b.log("Enabled dms systemd user service")
switch wm {
case deps.WindowManagerNiri:
if err := exec.CommandContext(ctx, "systemctl", "--user", "add-wants", "niri.service", "dms").Run(); err != nil {
b.log("Warning: failed to add dms as a want for niri.service")
}
case deps.WindowManagerHyprland:
if err := exec.CommandContext(ctx, "systemctl", "--user", "add-wants", "hyprland-session.target", "dms").Run(); err != nil {
b.log("Warning: failed to add dms as a want for hyprland-session.target")
}
}
return nil
}
func (b *BaseDistribution) WriteWindowManagerConfig(wm deps.WindowManager) error {
if wm == deps.WindowManagerHyprland {
if err := b.WriteHyprlandSessionTarget(); err != nil {
return fmt.Errorf("failed to write hyprland session target: %w", err)
}
}
return nil
}
func (b *BaseDistribution) WriteHyprlandSessionTarget() error {
homeDir, err := os.UserHomeDir()
if err != nil {
return fmt.Errorf("failed to get home directory: %w", err)
}
targetDir := filepath.Join(homeDir, ".config", "systemd", "user")
if err := os.MkdirAll(targetDir, 0755); err != nil {
return fmt.Errorf("failed to create systemd user directory: %w", err)
}
targetPath := filepath.Join(targetDir, "hyprland-session.target")
content := `[Unit]
Description=Hyprland Session Target
Requires=graphical-session.target
After=graphical-session.target
`
if err := os.WriteFile(targetPath, []byte(content), 0644); err != nil {
return fmt.Errorf("failed to write hyprland-session.target: %w", err)
}
b.log(fmt.Sprintf("Wrote hyprland-session.target to %s", targetPath))
return nil return nil
} }

View File

@@ -7,7 +7,6 @@ import (
"testing" "testing"
"github.com/AvengeMedia/DankMaterialShell/core/internal/deps" "github.com/AvengeMedia/DankMaterialShell/core/internal/deps"
"github.com/AvengeMedia/DankMaterialShell/core/internal/utils"
) )
func TestBaseDistribution_detectDMS_NotInstalled(t *testing.T) { func TestBaseDistribution_detectDMS_NotInstalled(t *testing.T) {
@@ -37,7 +36,7 @@ func TestBaseDistribution_detectDMS_NotInstalled(t *testing.T) {
} }
func TestBaseDistribution_detectDMS_Installed(t *testing.T) { func TestBaseDistribution_detectDMS_Installed(t *testing.T) {
if !utils.CommandExists("git") { if !commandExists("git") {
t.Skip("git not available") t.Skip("git not available")
} }
@@ -81,7 +80,7 @@ func TestBaseDistribution_detectDMS_Installed(t *testing.T) {
} }
func TestBaseDistribution_detectDMS_NeedsUpdate(t *testing.T) { func TestBaseDistribution_detectDMS_NeedsUpdate(t *testing.T) {
if !utils.CommandExists("git") { if !commandExists("git") {
t.Skip("git not available") t.Skip("git not available")
} }
@@ -165,6 +164,11 @@ func TestBaseDistribution_NewBaseDistribution(t *testing.T) {
} }
} }
func commandExists(cmd string) bool {
_, err := exec.LookPath(cmd)
return err == nil
}
func TestBaseDistribution_versionCompare(t *testing.T) { func TestBaseDistribution_versionCompare(t *testing.T) {
logChan := make(chan string, 10) logChan := make(chan string, 10)
defer close(logChan) defer close(logChan)

View File

@@ -75,15 +75,45 @@ func (d *DebianDistribution) DetectDependenciesWithTerminal(ctx context.Context,
} }
func (d *DebianDistribution) detectXDGPortal() deps.Dependency { func (d *DebianDistribution) detectXDGPortal() deps.Dependency {
return d.detectPackage("xdg-desktop-portal-gtk", "Desktop integration portal for GTK", d.packageInstalled("xdg-desktop-portal-gtk")) status := deps.StatusMissing
if d.packageInstalled("xdg-desktop-portal-gtk") {
status = deps.StatusInstalled
}
return deps.Dependency{
Name: "xdg-desktop-portal-gtk",
Status: status,
Description: "Desktop integration portal for GTK",
Required: true,
}
} }
func (d *DebianDistribution) detectXwaylandSatellite() deps.Dependency { func (d *DebianDistribution) detectXwaylandSatellite() deps.Dependency {
return d.detectCommand("xwayland-satellite", "Xwayland support") status := deps.StatusMissing
if d.commandExists("xwayland-satellite") {
status = deps.StatusInstalled
}
return deps.Dependency{
Name: "xwayland-satellite",
Status: status,
Description: "Xwayland support",
Required: true,
}
} }
func (d *DebianDistribution) detectAccountsService() deps.Dependency { func (d *DebianDistribution) detectAccountsService() deps.Dependency {
return d.detectPackage("accountsservice", "D-Bus interface for user account query and manipulation", d.packageInstalled("accountsservice")) status := deps.StatusMissing
if d.packageInstalled("accountsservice") {
status = deps.StatusInstalled
}
return deps.Dependency{
Name: "accountsservice",
Status: status,
Description: "D-Bus interface for user account query and manipulation",
Required: true,
}
} }
func (d *DebianDistribution) packageInstalled(pkg string) bool { func (d *DebianDistribution) packageInstalled(pkg string) bool {
@@ -178,7 +208,7 @@ func (d *DebianDistribution) InstallPrerequisites(ctx context.Context, sudoPassw
checkCmd := exec.CommandContext(ctx, "dpkg", "-l", "build-essential") checkCmd := exec.CommandContext(ctx, "dpkg", "-l", "build-essential")
if err := checkCmd.Run(); err != nil { if err := checkCmd.Run(); err != nil {
cmd := ExecSudoCommand(ctx, sudoPassword, "DEBIAN_FRONTEND=noninteractive apt-get install -y build-essential") cmd := ExecSudoCommand(ctx, sudoPassword, "apt-get install -y build-essential")
if err := d.runWithProgress(cmd, progressChan, PhasePrerequisites, 0.08, 0.09); err != nil { if err := d.runWithProgress(cmd, progressChan, PhasePrerequisites, 0.08, 0.09); err != nil {
return fmt.Errorf("failed to install build-essential: %w", err) return fmt.Errorf("failed to install build-essential: %w", err)
} }
@@ -195,7 +225,7 @@ func (d *DebianDistribution) InstallPrerequisites(ctx context.Context, sudoPassw
} }
devToolsCmd := ExecSudoCommand(ctx, sudoPassword, devToolsCmd := ExecSudoCommand(ctx, sudoPassword,
"DEBIAN_FRONTEND=noninteractive apt-get install -y curl wget git cmake ninja-build pkg-config libxcb-cursor-dev libglib2.0-dev libpolkit-agent-1-dev libjpeg-dev libpugixml-dev") "apt-get install -y curl wget git cmake ninja-build pkg-config libxcb-cursor-dev libglib2.0-dev libpolkit-agent-1-dev libjpeg-dev libpugixml-dev")
if err := d.runWithProgress(devToolsCmd, progressChan, PhasePrerequisites, 0.10, 0.12); err != nil { if err := d.runWithProgress(devToolsCmd, progressChan, PhasePrerequisites, 0.10, 0.12); err != nil {
return fmt.Errorf("failed to install development tools: %w", err) return fmt.Errorf("failed to install development tools: %w", err)
} }
@@ -308,11 +338,7 @@ func (d *DebianDistribution) InstallPackages(ctx context.Context, dependencies [
d.log(fmt.Sprintf("Warning: failed to write environment config: %v", err)) d.log(fmt.Sprintf("Warning: failed to write environment config: %v", err))
} }
if err := d.WriteWindowManagerConfig(wm); err != nil { if err := d.EnableDMSService(ctx); err != nil {
d.log(fmt.Sprintf("Warning: failed to write window manager config: %v", err))
}
if err := d.EnableDMSService(ctx, wm); err != nil {
d.log(fmt.Sprintf("Warning: failed to enable dms service: %v", err)) d.log(fmt.Sprintf("Warning: failed to enable dms service: %v", err))
} }
@@ -423,7 +449,7 @@ func (d *DebianDistribution) enableOBSRepos(ctx context.Context, obsPkgs []Packa
CommandInfo: fmt.Sprintf("curl & gpg to add key for %s", pkg.RepoURL), CommandInfo: fmt.Sprintf("curl & gpg to add key for %s", pkg.RepoURL),
} }
keyCmd := fmt.Sprintf("bash -c 'rm -f %s && curl -fsSL %s/Release.key | gpg --batch --dearmor -o %s'", keyringPath, baseURL, keyringPath) keyCmd := fmt.Sprintf("curl -fsSL %s/Release.key | gpg --dearmor -o %s", baseURL, keyringPath)
cmd := ExecSudoCommand(ctx, sudoPassword, keyCmd) cmd := ExecSudoCommand(ctx, sudoPassword, keyCmd)
if err := d.runWithProgress(cmd, progressChan, PhaseSystemPackages, 0.18, 0.20); err != nil { if err := d.runWithProgress(cmd, progressChan, PhaseSystemPackages, 0.18, 0.20); err != nil {
return fmt.Errorf("failed to add OBS GPG key for %s: %w", pkg.RepoURL, err) return fmt.Errorf("failed to add OBS GPG key for %s: %w", pkg.RepoURL, err)
@@ -441,7 +467,7 @@ func (d *DebianDistribution) enableOBSRepos(ctx context.Context, obsPkgs []Packa
} }
addRepoCmd := ExecSudoCommand(ctx, sudoPassword, addRepoCmd := ExecSudoCommand(ctx, sudoPassword,
fmt.Sprintf("bash -c \"echo '%s' | tee %s\"", repoLine, listFile)) fmt.Sprintf("echo '%s' | tee %s", repoLine, listFile))
if err := d.runWithProgress(addRepoCmd, progressChan, PhaseSystemPackages, 0.20, 0.22); err != nil { if err := d.runWithProgress(addRepoCmd, progressChan, PhaseSystemPackages, 0.20, 0.22); err != nil {
return fmt.Errorf("failed to add OBS repo %s: %w", pkg.RepoURL, err) return fmt.Errorf("failed to add OBS repo %s: %w", pkg.RepoURL, err)
} }
@@ -476,7 +502,7 @@ func (d *DebianDistribution) installAPTPackages(ctx context.Context, packages []
d.log(fmt.Sprintf("Installing APT packages: %s", strings.Join(packages, ", "))) d.log(fmt.Sprintf("Installing APT packages: %s", strings.Join(packages, ", ")))
args := []string{"DEBIAN_FRONTEND=noninteractive", "apt-get", "install", "-y"} args := []string{"apt-get", "install", "-y"}
args = append(args, packages...) args = append(args, packages...)
progressChan <- InstallProgressMsg{ progressChan <- InstallProgressMsg{
@@ -586,7 +612,7 @@ func (d *DebianDistribution) installRust(ctx context.Context, sudoPassword strin
CommandInfo: "sudo apt-get install rustup", CommandInfo: "sudo apt-get install rustup",
} }
rustupInstallCmd := ExecSudoCommand(ctx, sudoPassword, "DEBIAN_FRONTEND=noninteractive apt-get install -y rustup") rustupInstallCmd := ExecSudoCommand(ctx, sudoPassword, "apt-get install -y rustup")
if err := d.runWithProgress(rustupInstallCmd, progressChan, PhaseSystemPackages, 0.82, 0.83); err != nil { if err := d.runWithProgress(rustupInstallCmd, progressChan, PhaseSystemPackages, 0.82, 0.83); err != nil {
return fmt.Errorf("failed to install rustup: %w", err) return fmt.Errorf("failed to install rustup: %w", err)
} }
@@ -625,7 +651,7 @@ func (d *DebianDistribution) installGo(ctx context.Context, sudoPassword string,
CommandInfo: "sudo apt-get install golang-go", CommandInfo: "sudo apt-get install golang-go",
} }
installCmd := ExecSudoCommand(ctx, sudoPassword, "DEBIAN_FRONTEND=noninteractive apt-get install -y golang-go") installCmd := ExecSudoCommand(ctx, sudoPassword, "apt-get install -y golang-go")
return d.runWithProgress(installCmd, progressChan, PhaseSystemPackages, 0.87, 0.90) return d.runWithProgress(installCmd, progressChan, PhaseSystemPackages, 0.87, 0.90)
} }

View File

@@ -97,7 +97,17 @@ func (f *FedoraDistribution) DetectDependenciesWithTerminal(ctx context.Context,
} }
func (f *FedoraDistribution) detectXDGPortal() deps.Dependency { func (f *FedoraDistribution) detectXDGPortal() deps.Dependency {
return f.detectPackage("xdg-desktop-portal-gtk", "Desktop integration portal for GTK", f.packageInstalled("xdg-desktop-portal-gtk")) status := deps.StatusMissing
if f.packageInstalled("xdg-desktop-portal-gtk") {
status = deps.StatusInstalled
}
return deps.Dependency{
Name: "xdg-desktop-portal-gtk",
Status: status,
Description: "Desktop integration portal for GTK",
Required: true,
}
} }
func (f *FedoraDistribution) packageInstalled(pkg string) bool { func (f *FedoraDistribution) packageInstalled(pkg string) bool {
@@ -156,7 +166,10 @@ func (f *FedoraDistribution) getDmsMapping(variant deps.PackageVariant) PackageM
return PackageMapping{Name: "dms", Repository: RepoTypeCOPR, RepoURL: "avengemedia/dms"} return PackageMapping{Name: "dms", Repository: RepoTypeCOPR, RepoURL: "avengemedia/dms"}
} }
func (f *FedoraDistribution) getHyprlandMapping(_ deps.PackageVariant) PackageMapping { func (f *FedoraDistribution) getHyprlandMapping(variant deps.PackageVariant) PackageMapping {
if variant == deps.VariantGit {
return PackageMapping{Name: "hyprland-git", Repository: RepoTypeCOPR, RepoURL: "solopasha/hyprland"}
}
return PackageMapping{Name: "hyprland", Repository: RepoTypeCOPR, RepoURL: "solopasha/hyprland"} return PackageMapping{Name: "hyprland", Repository: RepoTypeCOPR, RepoURL: "solopasha/hyprland"}
} }
@@ -164,7 +177,7 @@ func (f *FedoraDistribution) getNiriMapping(variant deps.PackageVariant) Package
if variant == deps.VariantGit { if variant == deps.VariantGit {
return PackageMapping{Name: "niri", Repository: RepoTypeCOPR, RepoURL: "yalter/niri-git"} return PackageMapping{Name: "niri", Repository: RepoTypeCOPR, RepoURL: "yalter/niri-git"}
} }
return PackageMapping{Name: "niri", Repository: RepoTypeCOPR, RepoURL: "yalter/niri"} return PackageMapping{Name: "niri", Repository: RepoTypeSystem}
} }
func (f *FedoraDistribution) detectXwaylandSatellite() deps.Dependency { func (f *FedoraDistribution) detectXwaylandSatellite() deps.Dependency {
@@ -349,11 +362,7 @@ func (f *FedoraDistribution) InstallPackages(ctx context.Context, dependencies [
f.log(fmt.Sprintf("Warning: failed to write environment config: %v", err)) f.log(fmt.Sprintf("Warning: failed to write environment config: %v", err))
} }
if err := f.WriteWindowManagerConfig(wm); err != nil { if err := f.EnableDMSService(ctx); err != nil {
f.log(fmt.Sprintf("Warning: failed to write window manager config: %v", err))
}
if err := f.EnableDMSService(ctx, wm); err != nil {
f.log(fmt.Sprintf("Warning: failed to enable dms service: %v", err)) f.log(fmt.Sprintf("Warning: failed to enable dms service: %v", err))
} }

View File

@@ -95,6 +95,7 @@ func (g *GentooDistribution) DetectDependenciesWithTerminal(ctx context.Context,
dependencies = append(dependencies, g.detectWindowManager(wm)) dependencies = append(dependencies, g.detectWindowManager(wm))
dependencies = append(dependencies, g.detectQuickshell()) dependencies = append(dependencies, g.detectQuickshell())
dependencies = append(dependencies, g.detectXDGPortal()) dependencies = append(dependencies, g.detectXDGPortal())
dependencies = append(dependencies, g.detectPolkitAgent())
dependencies = append(dependencies, g.detectAccountsService()) dependencies = append(dependencies, g.detectAccountsService())
if wm == deps.WindowManagerHyprland { if wm == deps.WindowManagerHyprland {
@@ -113,15 +114,59 @@ func (g *GentooDistribution) DetectDependenciesWithTerminal(ctx context.Context,
} }
func (g *GentooDistribution) detectXDGPortal() deps.Dependency { func (g *GentooDistribution) detectXDGPortal() deps.Dependency {
return g.detectPackage("xdg-desktop-portal-gtk", "Desktop integration portal for GTK", g.packageInstalled("sys-apps/xdg-desktop-portal-gtk")) status := deps.StatusMissing
if g.packageInstalled("sys-apps/xdg-desktop-portal-gtk") {
status = deps.StatusInstalled
}
return deps.Dependency{
Name: "xdg-desktop-portal-gtk",
Status: status,
Description: "Desktop integration portal for GTK",
Required: true,
}
}
func (g *GentooDistribution) detectPolkitAgent() deps.Dependency {
status := deps.StatusMissing
if g.packageInstalled("mate-extra/mate-polkit") {
status = deps.StatusInstalled
}
return deps.Dependency{
Name: "mate-polkit",
Status: status,
Description: "PolicyKit authentication agent",
Required: true,
}
} }
func (g *GentooDistribution) detectXwaylandSatellite() deps.Dependency { func (g *GentooDistribution) detectXwaylandSatellite() deps.Dependency {
return g.detectPackage("xwayland-satellite", "Xwayland support", g.packageInstalled("gui-apps/xwayland-satellite")) status := deps.StatusMissing
if g.packageInstalled("gui-apps/xwayland-satellite") {
status = deps.StatusInstalled
}
return deps.Dependency{
Name: "xwayland-satellite",
Status: status,
Description: "Xwayland support",
Required: true,
}
} }
func (g *GentooDistribution) detectAccountsService() deps.Dependency { func (g *GentooDistribution) detectAccountsService() deps.Dependency {
return g.detectPackage("accountsservice", "D-Bus interface for user account query and manipulation", g.packageInstalled("sys-apps/accountsservice")) status := deps.StatusMissing
if g.packageInstalled("sys-apps/accountsservice") {
status = deps.StatusInstalled
}
return deps.Dependency{
Name: "accountsservice",
Status: status,
Description: "D-Bus interface for user account query and manipulation",
Required: true,
}
} }
func (g *GentooDistribution) packageInstalled(pkg string) bool { func (g *GentooDistribution) packageInstalled(pkg string) bool {
@@ -142,6 +187,7 @@ func (g *GentooDistribution) GetPackageMappingWithVariants(wm deps.WindowManager
"alacritty": {Name: "x11-terms/alacritty", Repository: RepoTypeSystem, UseFlags: "X wayland"}, "alacritty": {Name: "x11-terms/alacritty", Repository: RepoTypeSystem, UseFlags: "X wayland"},
"wl-clipboard": {Name: "gui-apps/wl-clipboard", Repository: RepoTypeSystem}, "wl-clipboard": {Name: "gui-apps/wl-clipboard", Repository: RepoTypeSystem},
"xdg-desktop-portal-gtk": {Name: "sys-apps/xdg-desktop-portal-gtk", Repository: RepoTypeSystem, UseFlags: "wayland X"}, "xdg-desktop-portal-gtk": {Name: "sys-apps/xdg-desktop-portal-gtk", Repository: RepoTypeSystem, UseFlags: "wayland X"},
"mate-polkit": {Name: "mate-extra/mate-polkit", Repository: RepoTypeSystem},
"accountsservice": {Name: "sys-apps/accountsservice", Repository: RepoTypeSystem}, "accountsservice": {Name: "sys-apps/accountsservice", Repository: RepoTypeSystem},
"qtbase": {Name: "dev-qt/qtbase", Repository: RepoTypeSystem, UseFlags: "wayland opengl vulkan widgets"}, "qtbase": {Name: "dev-qt/qtbase", Repository: RepoTypeSystem, UseFlags: "wayland opengl vulkan widgets"},
@@ -177,8 +223,12 @@ func (g *GentooDistribution) getDmsMapping(_ deps.PackageVariant) PackageMapping
return PackageMapping{Name: "dms", Repository: RepoTypeManual, BuildFunc: "installDankMaterialShell"} return PackageMapping{Name: "dms", Repository: RepoTypeManual, BuildFunc: "installDankMaterialShell"}
} }
func (g *GentooDistribution) getHyprlandMapping(_ deps.PackageVariant) PackageMapping { func (g *GentooDistribution) getHyprlandMapping(variant deps.PackageVariant) PackageMapping {
return PackageMapping{Name: "gui-wm/hyprland", Repository: RepoTypeSystem, UseFlags: "X", AcceptKeywords: g.getArchKeyword()} archKeyword := g.getArchKeyword()
if variant == deps.VariantGit {
return PackageMapping{Name: "gui-wm/hyprland", Repository: RepoTypeGURU, UseFlags: "X", AcceptKeywords: archKeyword}
}
return PackageMapping{Name: "gui-wm/hyprland", Repository: RepoTypeSystem, UseFlags: "X", AcceptKeywords: archKeyword}
} }
func (g *GentooDistribution) getNiriMapping(_ deps.PackageVariant) PackageMapping { func (g *GentooDistribution) getNiriMapping(_ deps.PackageVariant) PackageMapping {
@@ -406,11 +456,7 @@ func (g *GentooDistribution) InstallPackages(ctx context.Context, dependencies [
g.log(fmt.Sprintf("Warning: failed to write environment config: %v", err)) g.log(fmt.Sprintf("Warning: failed to write environment config: %v", err))
} }
if err := g.WriteWindowManagerConfig(wm); err != nil { if err := g.EnableDMSService(ctx); err != nil {
g.log(fmt.Sprintf("Warning: failed to write window manager config: %v", err))
}
if err := g.EnableDMSService(ctx, wm); err != nil {
g.log(fmt.Sprintf("Warning: failed to enable dms service: %v", err)) g.log(fmt.Sprintf("Warning: failed to enable dms service: %v", err))
} }

View File

@@ -87,7 +87,17 @@ func (o *OpenSUSEDistribution) DetectDependenciesWithTerminal(ctx context.Contex
} }
func (o *OpenSUSEDistribution) detectXDGPortal() deps.Dependency { func (o *OpenSUSEDistribution) detectXDGPortal() deps.Dependency {
return o.detectPackage("xdg-desktop-portal-gtk", "Desktop integration portal for GTK", o.packageInstalled("xdg-desktop-portal-gtk")) status := deps.StatusMissing
if o.packageInstalled("xdg-desktop-portal-gtk") {
status = deps.StatusInstalled
}
return deps.Dependency{
Name: "xdg-desktop-portal-gtk",
Status: status,
Description: "Desktop integration portal for GTK",
Required: true,
}
} }
func (o *OpenSUSEDistribution) packageInstalled(pkg string) bool { func (o *OpenSUSEDistribution) packageInstalled(pkg string) bool {
@@ -367,11 +377,7 @@ func (o *OpenSUSEDistribution) InstallPackages(ctx context.Context, dependencies
o.log(fmt.Sprintf("Warning: failed to write environment config: %v", err)) o.log(fmt.Sprintf("Warning: failed to write environment config: %v", err))
} }
if err := o.WriteWindowManagerConfig(wm); err != nil { if err := o.EnableDMSService(ctx); err != nil {
o.log(fmt.Sprintf("Warning: failed to write window manager config: %v", err))
}
if err := o.EnableDMSService(ctx, wm); err != nil {
o.log(fmt.Sprintf("Warning: failed to enable dms service: %v", err)) o.log(fmt.Sprintf("Warning: failed to enable dms service: %v", err))
} }
@@ -466,7 +472,7 @@ func (o *OpenSUSEDistribution) enableOBSRepos(ctx context.Context, obsPkgs []Pac
cmd := ExecSudoCommand(ctx, sudoPassword, cmd := ExecSudoCommand(ctx, sudoPassword,
fmt.Sprintf("zypper addrepo -f %s", repoURL)) fmt.Sprintf("zypper addrepo -f %s", repoURL))
if err := o.runWithProgress(cmd, progressChan, PhaseSystemPackages, 0.20, 0.22); err != nil { if err := o.runWithProgress(cmd, progressChan, PhaseSystemPackages, 0.20, 0.22); err != nil {
o.log(fmt.Sprintf("OBS repo %s add failed (may already exist): %v", pkg.RepoURL, err)) return fmt.Errorf("failed to enable OBS repo %s: %w", pkg.RepoURL, err)
} }
enabledRepos[pkg.RepoURL] = true enabledRepos[pkg.RepoURL] = true

View File

@@ -85,15 +85,45 @@ func (u *UbuntuDistribution) DetectDependenciesWithTerminal(ctx context.Context,
} }
func (u *UbuntuDistribution) detectXDGPortal() deps.Dependency { func (u *UbuntuDistribution) detectXDGPortal() deps.Dependency {
return u.detectPackage("xdg-desktop-portal-gtk", "Desktop integration portal for GTK", u.packageInstalled("xdg-desktop-portal-gtk")) status := deps.StatusMissing
if u.packageInstalled("xdg-desktop-portal-gtk") {
status = deps.StatusInstalled
}
return deps.Dependency{
Name: "xdg-desktop-portal-gtk",
Status: status,
Description: "Desktop integration portal for GTK",
Required: true,
}
} }
func (u *UbuntuDistribution) detectXwaylandSatellite() deps.Dependency { func (u *UbuntuDistribution) detectXwaylandSatellite() deps.Dependency {
return u.detectCommand("xwayland-satellite", "Xwayland support") status := deps.StatusMissing
if u.commandExists("xwayland-satellite") {
status = deps.StatusInstalled
}
return deps.Dependency{
Name: "xwayland-satellite",
Status: status,
Description: "Xwayland support",
Required: true,
}
} }
func (u *UbuntuDistribution) detectAccountsService() deps.Dependency { func (u *UbuntuDistribution) detectAccountsService() deps.Dependency {
return u.detectPackage("accountsservice", "D-Bus interface for user account query and manipulation", u.packageInstalled("accountsservice")) status := deps.StatusMissing
if u.packageInstalled("accountsservice") {
status = deps.StatusInstalled
}
return deps.Dependency{
Name: "accountsservice",
Status: status,
Description: "D-Bus interface for user account query and manipulation",
Required: true,
}
} }
func (u *UbuntuDistribution) packageInstalled(pkg string) bool { func (u *UbuntuDistribution) packageInstalled(pkg string) bool {
@@ -327,11 +357,7 @@ func (u *UbuntuDistribution) InstallPackages(ctx context.Context, dependencies [
u.log(fmt.Sprintf("Warning: failed to write environment config: %v", err)) u.log(fmt.Sprintf("Warning: failed to write environment config: %v", err))
} }
if err := u.WriteWindowManagerConfig(wm); err != nil { if err := u.EnableDMSService(ctx); err != nil {
u.log(fmt.Sprintf("Warning: failed to write window manager config: %v", err))
}
if err := u.EnableDMSService(ctx, wm); err != nil {
u.log(fmt.Sprintf("Warning: failed to enable dms service: %v", err)) u.log(fmt.Sprintf("Warning: failed to enable dms service: %v", err))
} }

View File

@@ -286,13 +286,6 @@ func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
return m, loadInstalledPlugins return m, loadInstalledPlugins
} }
return m, nil return m, nil
case pluginUpdatedMsg:
if msg.err != nil {
m.installedPluginsError = msg.err.Error()
} else {
m.installedPluginsError = ""
}
return m, nil
case pluginInstalledMsg: case pluginInstalledMsg:
if msg.err != nil { if msg.err != nil {
m.pluginsError = msg.err.Error() m.pluginsError = msg.err.Error()

View File

@@ -75,13 +75,14 @@ type MenuItem struct {
func NewModel(version string) Model { func NewModel(version string) Model {
detector, _ := NewDetector() detector, _ := NewDetector()
dependencies := detector.GetInstalledComponents()
var dependencies []DependencyInfo // Use the proper detection method for both window managers
var hyprlandInstalled, niriInstalled bool hyprlandInstalled, niriInstalled, err := detector.GetWindowManagerStatus()
if err != nil {
if detector != nil { // Fallback to false if detection fails
dependencies = detector.GetInstalledComponents() hyprlandInstalled = false
hyprlandInstalled, niriInstalled, _ = detector.GetWindowManagerStatus() niriInstalled = false
} }
m := Model{ m := Model{
@@ -200,13 +201,6 @@ func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
return m, loadInstalledPlugins return m, loadInstalledPlugins
} }
return m, nil return m, nil
case pluginUpdatedMsg:
if msg.err != nil {
m.installedPluginsError = msg.err.Error()
} else {
m.installedPluginsError = ""
}
return m, nil
case pluginInstalledMsg: case pluginInstalledMsg:
if msg.err != nil { if msg.err != nil {
m.pluginsError = msg.err.Error() m.pluginsError = msg.err.Error()

View File

@@ -227,11 +227,6 @@ func (m Model) updatePluginInstalledDetail(msg tea.KeyMsg) (tea.Model, tea.Cmd)
plugin := m.installedPluginsList[m.selectedInstalledIndex] plugin := m.installedPluginsList[m.selectedInstalledIndex]
return m, uninstallPlugin(plugin) return m, uninstallPlugin(plugin)
} }
case "p":
if m.selectedInstalledIndex < len(m.installedPluginsList) {
plugin := m.installedPluginsList[m.selectedInstalledIndex]
return m, updatePlugin(plugin)
}
} }
return m, nil return m, nil
} }
@@ -251,11 +246,6 @@ type pluginInstalledMsg struct {
err error err error
} }
type pluginUpdatedMsg struct {
pluginName string
err error
}
func loadInstalledPlugins() tea.Msg { func loadInstalledPlugins() tea.Msg {
manager, err := plugins.NewManager() manager, err := plugins.NewManager()
if err != nil { if err != nil {
@@ -347,31 +337,3 @@ func uninstallPlugin(plugin pluginInfo) tea.Cmd {
return pluginUninstalledMsg{pluginName: plugin.Name} return pluginUninstalledMsg{pluginName: plugin.Name}
} }
} }
func updatePlugin(plugin pluginInfo) tea.Cmd {
return func() tea.Msg {
manager, err := plugins.NewManager()
if err != nil {
return pluginUpdatedMsg{pluginName: plugin.Name, err: err}
}
p := plugins.Plugin{
ID: plugin.ID,
Name: plugin.Name,
Category: plugin.Category,
Author: plugin.Author,
Description: plugin.Description,
Repo: plugin.Repo,
Path: plugin.Path,
Capabilities: plugin.Capabilities,
Compositors: plugin.Compositors,
Dependencies: plugin.Dependencies,
}
if err := manager.Update(p); err != nil {
return pluginUpdatedMsg{pluginName: plugin.Name, err: err}
}
return pluginUpdatedMsg{pluginName: plugin.Name}
}
}

View File

@@ -11,7 +11,6 @@ import (
"github.com/AvengeMedia/DankMaterialShell/core/internal/config" "github.com/AvengeMedia/DankMaterialShell/core/internal/config"
"github.com/AvengeMedia/DankMaterialShell/core/internal/distros" "github.com/AvengeMedia/DankMaterialShell/core/internal/distros"
"github.com/AvengeMedia/DankMaterialShell/core/internal/utils"
) )
// DetectDMSPath checks for DMS installation following XDG Base Directory specification // DetectDMSPath checks for DMS installation following XDG Base Directory specification
@@ -23,10 +22,10 @@ func DetectDMSPath() (string, error) {
func DetectCompositors() []string { func DetectCompositors() []string {
var compositors []string var compositors []string
if utils.CommandExists("niri") { if commandExists("niri") {
compositors = append(compositors, "niri") compositors = append(compositors, "niri")
} }
if utils.CommandExists("Hyprland") { if commandExists("Hyprland") {
compositors = append(compositors, "Hyprland") compositors = append(compositors, "Hyprland")
} }
@@ -63,7 +62,7 @@ func PromptCompositorChoice(compositors []string) (string, error) {
// EnsureGreetdInstalled checks if greetd is installed and installs it if not // EnsureGreetdInstalled checks if greetd is installed and installs it if not
func EnsureGreetdInstalled(logFunc func(string), sudoPassword string) error { func EnsureGreetdInstalled(logFunc func(string), sudoPassword string) error {
if utils.CommandExists("greetd") { if commandExists("greetd") {
logFunc("✓ greetd is already installed") logFunc("✓ greetd is already installed")
return nil return nil
} }
@@ -145,7 +144,7 @@ func EnsureGreetdInstalled(logFunc func(string), sudoPassword string) error {
// CopyGreeterFiles installs the dms-greeter wrapper and sets up cache directory // CopyGreeterFiles installs the dms-greeter wrapper and sets up cache directory
func CopyGreeterFiles(dmsPath, compositor string, logFunc func(string), sudoPassword string) error { func CopyGreeterFiles(dmsPath, compositor string, logFunc func(string), sudoPassword string) error {
// Check if dms-greeter is already in PATH // Check if dms-greeter is already in PATH
if utils.CommandExists("dms-greeter") { if commandExists("dms-greeter") {
logFunc("✓ dms-greeter wrapper already installed") logFunc("✓ dms-greeter wrapper already installed")
} else { } else {
// Install the wrapper script // Install the wrapper script
@@ -205,7 +204,7 @@ func CopyGreeterFiles(dmsPath, compositor string, logFunc func(string), sudoPass
// SetupParentDirectoryACLs sets ACLs on parent directories to allow traversal // SetupParentDirectoryACLs sets ACLs on parent directories to allow traversal
func SetupParentDirectoryACLs(logFunc func(string), sudoPassword string) error { func SetupParentDirectoryACLs(logFunc func(string), sudoPassword string) error {
if !utils.CommandExists("setfacl") { if !commandExists("setfacl") {
logFunc("⚠ Warning: setfacl command not found. ACL support may not be available on this filesystem.") logFunc("⚠ Warning: setfacl command not found. ACL support may not be available on this filesystem.")
logFunc(" If theme sync doesn't work, you may need to install acl package:") logFunc(" If theme sync doesn't work, you may need to install acl package:")
logFunc(" - Fedora/RHEL: sudo dnf install acl") logFunc(" - Fedora/RHEL: sudo dnf install acl")
@@ -420,7 +419,7 @@ user = "greeter"
// Determine wrapper command path // Determine wrapper command path
wrapperCmd := "dms-greeter" wrapperCmd := "dms-greeter"
if !utils.CommandExists("dms-greeter") { if !commandExists("dms-greeter") {
wrapperCmd = "/usr/local/bin/dms-greeter" wrapperCmd = "/usr/local/bin/dms-greeter"
} }
@@ -487,3 +486,8 @@ func runSudoCmd(sudoPassword string, command string, args ...string) error {
cmd.Stderr = os.Stderr cmd.Stderr = os.Stderr
return cmd.Run() return cmd.Run()
} }
func commandExists(cmd string) bool {
_, err := exec.LookPath(cmd)
return err == nil
}

View File

@@ -5,8 +5,6 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"strings" "strings"
"github.com/AvengeMedia/DankMaterialShell/core/internal/utils"
) )
type DiscoveryConfig struct { type DiscoveryConfig struct {
@@ -16,7 +14,13 @@ type DiscoveryConfig struct {
func DefaultDiscoveryConfig() *DiscoveryConfig { func DefaultDiscoveryConfig() *DiscoveryConfig {
var searchPaths []string var searchPaths []string
configHome := utils.XDGConfigHome() configHome := os.Getenv("XDG_CONFIG_HOME")
if configHome == "" {
if homeDir, err := os.UserHomeDir(); err == nil {
configHome = filepath.Join(homeDir, ".config")
}
}
if configHome != "" { if configHome != "" {
searchPaths = append(searchPaths, filepath.Join(configHome, "DankMaterialShell", "cheatsheets")) searchPaths = append(searchPaths, filepath.Join(configHome, "DankMaterialShell", "cheatsheets"))
} }
@@ -39,7 +43,7 @@ func (d *DiscoveryConfig) FindJSONFiles() ([]string, error) {
var files []string var files []string
for _, searchPath := range d.SearchPaths { for _, searchPath := range d.SearchPaths {
expandedPath, err := utils.ExpandPath(searchPath) expandedPath, err := expandPath(searchPath)
if err != nil { if err != nil {
continue continue
} }
@@ -70,6 +74,20 @@ func (d *DiscoveryConfig) FindJSONFiles() ([]string, error) {
return files, nil return files, nil
} }
func expandPath(path string) (string, error) {
expandedPath := os.ExpandEnv(path)
if strings.HasPrefix(expandedPath, "~") {
home, err := os.UserHomeDir()
if err != nil {
return "", err
}
expandedPath = filepath.Join(home, expandedPath[1:])
}
return filepath.Clean(expandedPath), nil
}
type JSONProviderFactory func(filePath string) (Provider, error) type JSONProviderFactory func(filePath string) (Provider, error)
var jsonProviderFactory JSONProviderFactory var jsonProviderFactory JSONProviderFactory

View File

@@ -4,8 +4,6 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"testing" "testing"
"github.com/AvengeMedia/DankMaterialShell/core/internal/utils"
) )
func TestDefaultDiscoveryConfig(t *testing.T) { func TestDefaultDiscoveryConfig(t *testing.T) {
@@ -274,13 +272,13 @@ func TestExpandPathInDiscovery(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
result, err := utils.ExpandPath(tt.input) result, err := expandPath(tt.input)
if err != nil { if err != nil {
t.Fatalf("expandPath failed: %v", err) t.Fatalf("expandPath failed: %v", err)
} }
if result != tt.expected { if result != tt.expected {
t.Errorf("utils.ExpandPath(%q) = %q, want %q", tt.input, result, tt.expected) t.Errorf("expandPath(%q) = %q, want %q", tt.input, result, tt.expected)
} }
}) })
} }

View File

@@ -5,8 +5,6 @@ import (
"path/filepath" "path/filepath"
"regexp" "regexp"
"strings" "strings"
"github.com/AvengeMedia/DankMaterialShell/core/internal/utils"
) )
const ( const (
@@ -44,9 +42,14 @@ func NewHyprlandParser() *HyprlandParser {
} }
func (p *HyprlandParser) ReadContent(directory string) error { func (p *HyprlandParser) ReadContent(directory string) error {
expandedDir, err := utils.ExpandPath(directory) expandedDir := os.ExpandEnv(directory)
if err != nil { expandedDir = filepath.Clean(expandedDir)
return err if strings.HasPrefix(expandedDir, "~") {
home, err := os.UserHomeDir()
if err != nil {
return err
}
expandedDir = filepath.Join(home, expandedDir[1:])
} }
info, err := os.Stat(expandedDir) info, err := os.Stat(expandedDir)

View File

@@ -5,9 +5,9 @@ import (
"fmt" "fmt"
"os" "os"
"path/filepath" "path/filepath"
"strings"
"github.com/AvengeMedia/DankMaterialShell/core/internal/keybinds" "github.com/AvengeMedia/DankMaterialShell/core/internal/keybinds"
"github.com/AvengeMedia/DankMaterialShell/core/internal/utils"
) )
type JSONFileProvider struct { type JSONFileProvider struct {
@@ -20,7 +20,7 @@ func NewJSONFileProvider(filePath string) (*JSONFileProvider, error) {
return nil, fmt.Errorf("file path cannot be empty") return nil, fmt.Errorf("file path cannot be empty")
} }
expandedPath, err := utils.ExpandPath(filePath) expandedPath, err := expandPath(filePath)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to expand path: %w", err) return nil, fmt.Errorf("failed to expand path: %w", err)
} }
@@ -117,3 +117,17 @@ func (j *JSONFileProvider) GetCheatSheet() (*keybinds.CheatSheet, error) {
Binds: categorizedBinds, Binds: categorizedBinds,
}, nil }, nil
} }
func expandPath(path string) (string, error) {
expandedPath := os.ExpandEnv(path)
if strings.HasPrefix(expandedPath, "~") {
home, err := os.UserHomeDir()
if err != nil {
return "", err
}
expandedPath = filepath.Join(home, expandedPath[1:])
}
return filepath.Clean(expandedPath), nil
}

View File

@@ -4,8 +4,6 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"testing" "testing"
"github.com/AvengeMedia/DankMaterialShell/core/internal/utils"
) )
func TestNewJSONFileProvider(t *testing.T) { func TestNewJSONFileProvider(t *testing.T) {
@@ -268,13 +266,13 @@ func TestExpandPath(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
result, err := utils.ExpandPath(tt.input) result, err := expandPath(tt.input)
if err != nil { if err != nil {
t.Fatalf("expandPath failed: %v", err) t.Fatalf("expandPath failed: %v", err)
} }
if result != tt.expected { if result != tt.expected {
t.Errorf("utils.ExpandPath(%q) = %q, want %q", tt.input, result, tt.expected) t.Errorf("expandPath(%q) = %q, want %q", tt.input, result, tt.expected)
} }
}) })
} }

View File

@@ -5,8 +5,6 @@ import (
"path/filepath" "path/filepath"
"regexp" "regexp"
"strings" "strings"
"github.com/AvengeMedia/DankMaterialShell/core/internal/utils"
) )
const ( const (
@@ -36,9 +34,14 @@ func NewMangoWCParser() *MangoWCParser {
} }
func (p *MangoWCParser) ReadContent(path string) error { func (p *MangoWCParser) ReadContent(path string) error {
expandedPath, err := utils.ExpandPath(path) expandedPath := os.ExpandEnv(path)
if err != nil { expandedPath = filepath.Clean(expandedPath)
return err if strings.HasPrefix(expandedPath, "~") {
home, err := os.UserHomeDir()
if err != nil {
return err
}
expandedPath = filepath.Join(home, expandedPath[1:])
} }
info, err := os.Stat(expandedPath) info, err := os.Stat(expandedPath)

View File

@@ -6,11 +6,9 @@ import (
"os/exec" "os/exec"
"path/filepath" "path/filepath"
"sort" "sort"
"strconv"
"strings" "strings"
"github.com/AvengeMedia/DankMaterialShell/core/internal/keybinds" "github.com/AvengeMedia/DankMaterialShell/core/internal/keybinds"
"github.com/AvengeMedia/DankMaterialShell/core/internal/utils"
"github.com/sblinch/kdl-go" "github.com/sblinch/kdl-go"
"github.com/sblinch/kdl-go/document" "github.com/sblinch/kdl-go/document"
) )
@@ -31,7 +29,15 @@ func NewNiriProvider(configDir string) *NiriProvider {
} }
func defaultNiriConfigDir() string { func defaultNiriConfigDir() string {
return filepath.Join(utils.XDGConfigHome(), "niri") if configHome := os.Getenv("XDG_CONFIG_HOME"); configHome != "" {
return filepath.Join(configHome, "niri")
}
home, err := os.UserHomeDir()
if err != nil {
return ""
}
return filepath.Join(home, ".config", "niri")
} }
func (n *NiriProvider) Name() string { func (n *NiriProvider) Name() string {
@@ -148,13 +154,11 @@ func (n *NiriProvider) convertKeybind(kb *NiriKeyBinding, subcategory string, co
} }
bind := keybinds.Keybind{ bind := keybinds.Keybind{
Key: keyStr, Key: keyStr,
Description: kb.Description, Description: kb.Description,
Action: rawAction, Action: rawAction,
Subcategory: subcategory, Subcategory: subcategory,
Source: source, Source: source,
HideOnOverlay: kb.HideOnOverlay,
CooldownMs: kb.CooldownMs,
} }
if source == "dms" && conflicts != nil { if source == "dms" && conflicts != nil {
@@ -312,9 +316,7 @@ func (n *NiriProvider) extractOptions(node *document.Node) map[string]any {
opts["repeat"] = val.String() == "true" opts["repeat"] = val.String() == "true"
} }
if val, ok := node.Properties.Get("cooldown-ms"); ok { if val, ok := node.Properties.Get("cooldown-ms"); ok {
if ms, err := strconv.Atoi(val.String()); err == nil { opts["cooldown-ms"] = val.String()
opts["cooldown-ms"] = ms
}
} }
if val, ok := node.Properties.Get("allow-when-locked"); ok { if val, ok := node.Properties.Get("allow-when-locked"); ok {
opts["allow-when-locked"] = val.String() == "true" opts["allow-when-locked"] = val.String() == "true"
@@ -340,14 +342,7 @@ func (n *NiriProvider) buildBindNode(bind *overrideBind) *document.Node {
node.AddProperty("repeat", false, "") node.AddProperty("repeat", false, "")
} }
if v, ok := bind.Options["cooldown-ms"]; ok { if v, ok := bind.Options["cooldown-ms"]; ok {
switch val := v.(type) { node.AddProperty("cooldown-ms", v, "")
case int:
node.AddProperty("cooldown-ms", val, "")
case string:
if ms, err := strconv.Atoi(val); err == nil {
node.AddProperty("cooldown-ms", ms, "")
}
}
} }
if v, ok := bind.Options["allow-when-locked"]; ok && v == true { if v, ok := bind.Options["allow-when-locked"]; ok && v == true {
node.AddProperty("allow-when-locked", true, "") node.AddProperty("allow-when-locked", true, "")

View File

@@ -4,7 +4,6 @@ import (
"fmt" "fmt"
"os" "os"
"path/filepath" "path/filepath"
"strconv"
"strings" "strings"
"github.com/sblinch/kdl-go" "github.com/sblinch/kdl-go"
@@ -12,14 +11,12 @@ import (
) )
type NiriKeyBinding struct { type NiriKeyBinding struct {
Mods []string Mods []string
Key string Key string
Action string Action string
Args []string Args []string
Description string Description string
HideOnOverlay bool Source string
CooldownMs int
Source string
} }
type NiriSection struct { type NiriSection struct {
@@ -276,31 +273,19 @@ func (p *NiriParser) parseKeybindNode(node *document.Node, _ string) *NiriKeyBin
} }
var description string var description string
var hideOnOverlay bool
var cooldownMs int
if node.Properties != nil { if node.Properties != nil {
if val, ok := node.Properties.Get("hotkey-overlay-title"); ok { if val, ok := node.Properties.Get("hotkey-overlay-title"); ok {
switch val.ValueString() { description = val.ValueString()
case "null", "":
hideOnOverlay = true
default:
description = val.ValueString()
}
}
if val, ok := node.Properties.Get("cooldown-ms"); ok {
cooldownMs, _ = strconv.Atoi(val.String())
} }
} }
return &NiriKeyBinding{ return &NiriKeyBinding{
Mods: mods, Mods: mods,
Key: key, Key: key,
Action: action, Action: action,
Args: args, Args: args,
Description: description, Description: description,
HideOnOverlay: hideOnOverlay, Source: p.currentSource,
CooldownMs: cooldownMs,
Source: p.currentSource,
} }
} }

View File

@@ -2,7 +2,6 @@ package providers
import ( import (
"fmt" "fmt"
"os"
"strings" "strings"
"github.com/AvengeMedia/DankMaterialShell/core/internal/keybinds" "github.com/AvengeMedia/DankMaterialShell/core/internal/keybinds"
@@ -10,42 +9,18 @@ import (
type SwayProvider struct { type SwayProvider struct {
configPath string configPath string
isScroll bool
} }
func NewSwayProvider(configPath string) *SwayProvider { func NewSwayProvider(configPath string) *SwayProvider {
isScroll := false
_, scrollEnvSet := os.LookupEnv("SCROLLSOCK")
if configPath == "" { if configPath == "" {
if scrollEnvSet { configPath = "$HOME/.config/sway"
configPath = "$HOME/.config/scroll"
isScroll = true
} else {
configPath = "$HOME/.config/sway"
}
} else {
// Determine isScroll based on the provided config path
isScroll = strings.Contains(configPath, "scroll")
} }
return &SwayProvider{ return &SwayProvider{
configPath: configPath, configPath: configPath,
isScroll: isScroll,
} }
} }
func (s *SwayProvider) Name() string { func (s *SwayProvider) Name() string {
if s != nil && s.isScroll {
return "scroll"
}
if s == nil {
_, ok := os.LookupEnv("SCROLLSOCK")
if ok {
return "scroll"
}
}
return "sway" return "sway"
} }
@@ -58,13 +33,8 @@ func (s *SwayProvider) GetCheatSheet() (*keybinds.CheatSheet, error) {
categorizedBinds := make(map[string][]keybinds.Keybind) categorizedBinds := make(map[string][]keybinds.Keybind)
s.convertSection(section, "", categorizedBinds) s.convertSection(section, "", categorizedBinds)
cheatSheetTitle := "Sway Keybinds"
if s != nil && s.isScroll {
cheatSheetTitle = "Scroll Keybinds"
}
return &keybinds.CheatSheet{ return &keybinds.CheatSheet{
Title: cheatSheetTitle, Title: "Sway Keybinds",
Provider: s.Name(), Provider: s.Name(),
Binds: categorizedBinds, Binds: categorizedBinds,
}, nil }, nil

View File

@@ -5,8 +5,6 @@ import (
"path/filepath" "path/filepath"
"regexp" "regexp"
"strings" "strings"
"github.com/AvengeMedia/DankMaterialShell/core/internal/utils"
) )
const ( const (
@@ -44,9 +42,14 @@ func NewSwayParser() *SwayParser {
} }
func (p *SwayParser) ReadContent(path string) error { func (p *SwayParser) ReadContent(path string) error {
expandedPath, err := utils.ExpandPath(path) expandedPath := os.ExpandEnv(path)
if err != nil { expandedPath = filepath.Clean(expandedPath)
return err if strings.HasPrefix(expandedPath, "~") {
home, err := os.UserHomeDir()
if err != nil {
return err
}
expandedPath = filepath.Join(home, expandedPath[1:])
} }
info, err := os.Stat(expandedPath) info, err := os.Stat(expandedPath)

View File

@@ -1,14 +1,12 @@
package keybinds package keybinds
type Keybind struct { type Keybind struct {
Key string `json:"key"` Key string `json:"key"`
Description string `json:"desc"` Description string `json:"desc"`
Action string `json:"action,omitempty"` Action string `json:"action,omitempty"`
Subcategory string `json:"subcat,omitempty"` Subcategory string `json:"subcat,omitempty"`
Source string `json:"source,omitempty"` Source string `json:"source,omitempty"`
HideOnOverlay bool `json:"hideOnOverlay,omitempty"` Conflict *Keybind `json:"conflict,omitempty"`
CooldownMs int `json:"cooldownMs,omitempty"`
Conflict *Keybind `json:"conflict,omitempty"`
} }
type DMSBindsStatus struct { type DMSBindsStatus struct {

View File

@@ -13,7 +13,6 @@ import (
"github.com/AvengeMedia/DankMaterialShell/core/internal/dank16" "github.com/AvengeMedia/DankMaterialShell/core/internal/dank16"
"github.com/AvengeMedia/DankMaterialShell/core/internal/log" "github.com/AvengeMedia/DankMaterialShell/core/internal/log"
"github.com/AvengeMedia/DankMaterialShell/core/internal/utils"
) )
var ( var (
@@ -278,7 +277,7 @@ func appendConfig(opts *Options, cfgFile *os.File, checkCmd, fileName string) {
if _, err := os.Stat(configPath); err != nil { if _, err := os.Stat(configPath); err != nil {
return return
} }
if checkCmd != "skip" && !utils.CommandExists(checkCmd) { if checkCmd != "skip" && !commandExists(checkCmd) {
return return
} }
data, err := os.ReadFile(configPath) data, err := os.ReadFile(configPath)
@@ -294,7 +293,7 @@ func appendTerminalConfig(opts *Options, cfgFile *os.File, tmpDir, checkCmd, fil
if _, err := os.Stat(configPath); err != nil { if _, err := os.Stat(configPath); err != nil {
return return
} }
if checkCmd != "skip" && !utils.CommandExists(checkCmd) { if checkCmd != "skip" && !commandExists(checkCmd) {
return return
} }
data, err := os.ReadFile(configPath) data, err := os.ReadFile(configPath)
@@ -391,6 +390,11 @@ func extractTOMLSection(content, startMarker, endMarker string) string {
return content[startIdx : startIdx+endIdx] return content[startIdx : startIdx+endIdx]
} }
func commandExists(name string) bool {
_, err := exec.LookPath(name)
return err == nil
}
func checkMatugenVersion() { func checkMatugenVersion() {
matugenVersionOnce.Do(func() { matugenVersionOnce.Do(func() {
cmd := exec.Command("matugen", "--version") cmd := exec.Command("matugen", "--version")

View File

@@ -9,7 +9,6 @@ import (
"path/filepath" "path/filepath"
"strings" "strings"
"github.com/AvengeMedia/DankMaterialShell/core/internal/utils"
"github.com/spf13/afero" "github.com/spf13/afero"
) )
@@ -33,70 +32,33 @@ func NewManagerWithFs(fs afero.Fs) (*Manager, error) {
} }
func getPluginsDir() string { func getPluginsDir() string {
return filepath.Join(utils.XDGConfigHome(), "DankMaterialShell", "plugins") configHome := os.Getenv("XDG_CONFIG_HOME")
if configHome == "" {
homeDir, err := os.UserHomeDir()
if err != nil {
return filepath.Join(os.TempDir(), "DankMaterialShell", "plugins")
}
configHome = filepath.Join(homeDir, ".config")
}
return filepath.Join(configHome, "DankMaterialShell", "plugins")
} }
func (m *Manager) IsInstalled(plugin Plugin) (bool, error) { func (m *Manager) IsInstalled(plugin Plugin) (bool, error) {
path, err := m.findInstalledPath(plugin.ID) pluginPath := filepath.Join(m.pluginsDir, plugin.ID)
exists, err := afero.DirExists(m.fs, pluginPath)
if err != nil { if err != nil {
return false, err return false, err
} }
return path != "", nil if exists {
} return true, nil
}
func (m *Manager) findInstalledPath(pluginID string) (string, error) { systemPluginPath := filepath.Join("/etc/xdg/quickshell/dms-plugins", plugin.ID)
// Check user plugins directory systemExists, err := afero.DirExists(m.fs, systemPluginPath)
path, err := m.findInDir(m.pluginsDir, pluginID)
if err != nil { if err != nil {
return "", err return false, err
} }
if path != "" { return systemExists, nil
return path, nil
}
// Check system plugins directory
systemDir := "/etc/xdg/quickshell/dms-plugins"
return m.findInDir(systemDir, pluginID)
}
func (m *Manager) findInDir(dir, pluginID string) (string, error) {
// First, check if folder with exact ID name exists
exactPath := filepath.Join(dir, pluginID)
if exists, _ := afero.DirExists(m.fs, exactPath); exists {
return exactPath, nil
}
// Scan all folders and check plugin.json for matching ID
exists, err := afero.DirExists(m.fs, dir)
if err != nil || !exists {
return "", nil
}
entries, err := afero.ReadDir(m.fs, dir)
if err != nil {
return "", nil
}
for _, entry := range entries {
name := entry.Name()
if name == ".repos" || strings.HasSuffix(name, ".meta") {
continue
}
fullPath := filepath.Join(dir, name)
isPlugin := entry.IsDir() || entry.Mode()&os.ModeSymlink != 0
if !isPlugin {
if info, err := m.fs.Stat(fullPath); err == nil && info.IsDir() {
isPlugin = true
}
}
if isPlugin && m.getPluginID(fullPath) == pluginID {
return fullPath, nil
}
}
return "", nil
} }
func (m *Manager) Install(plugin Plugin) error { func (m *Manager) Install(plugin Plugin) error {
@@ -189,19 +151,25 @@ func (m *Manager) createSymlink(source, dest string) error {
} }
func (m *Manager) Update(plugin Plugin) error { func (m *Manager) Update(plugin Plugin) error {
pluginPath, err := m.findInstalledPath(plugin.ID) pluginPath := filepath.Join(m.pluginsDir, plugin.ID)
exists, err := afero.DirExists(m.fs, pluginPath)
if err != nil { if err != nil {
return fmt.Errorf("failed to find plugin: %w", err) return fmt.Errorf("failed to check if plugin exists: %w", err)
} }
if pluginPath == "" { if !exists {
systemPluginPath := filepath.Join("/etc/xdg/quickshell/dms-plugins", plugin.ID)
systemExists, err := afero.DirExists(m.fs, systemPluginPath)
if err != nil {
return fmt.Errorf("failed to check if plugin exists: %w", err)
}
if systemExists {
return fmt.Errorf("cannot update system plugin: %s", plugin.Name)
}
return fmt.Errorf("plugin not installed: %s", plugin.Name) return fmt.Errorf("plugin not installed: %s", plugin.Name)
} }
if strings.HasPrefix(pluginPath, "/etc/xdg/quickshell/dms-plugins") {
return fmt.Errorf("cannot update system plugin: %s", plugin.Name)
}
metaPath := pluginPath + ".meta" metaPath := pluginPath + ".meta"
metaExists, err := afero.Exists(m.fs, metaPath) metaExists, err := afero.Exists(m.fs, metaPath)
if err != nil { if err != nil {
@@ -241,19 +209,25 @@ func (m *Manager) Update(plugin Plugin) error {
} }
func (m *Manager) Uninstall(plugin Plugin) error { func (m *Manager) Uninstall(plugin Plugin) error {
pluginPath, err := m.findInstalledPath(plugin.ID) pluginPath := filepath.Join(m.pluginsDir, plugin.ID)
exists, err := afero.DirExists(m.fs, pluginPath)
if err != nil { if err != nil {
return fmt.Errorf("failed to find plugin: %w", err) return fmt.Errorf("failed to check if plugin exists: %w", err)
} }
if pluginPath == "" { if !exists {
systemPluginPath := filepath.Join("/etc/xdg/quickshell/dms-plugins", plugin.ID)
systemExists, err := afero.DirExists(m.fs, systemPluginPath)
if err != nil {
return fmt.Errorf("failed to check if plugin exists: %w", err)
}
if systemExists {
return fmt.Errorf("cannot uninstall system plugin: %s", plugin.Name)
}
return fmt.Errorf("plugin not installed: %s", plugin.Name) return fmt.Errorf("plugin not installed: %s", plugin.Name)
} }
if strings.HasPrefix(pluginPath, "/etc/xdg/quickshell/dms-plugins") {
return fmt.Errorf("cannot uninstall system plugin: %s", plugin.Name)
}
metaPath := pluginPath + ".meta" metaPath := pluginPath + ".meta"
metaExists, err := afero.Exists(m.fs, metaPath) metaExists, err := afero.Exists(m.fs, metaPath)
if err != nil { if err != nil {
@@ -395,174 +369,47 @@ func (m *Manager) ListInstalled() ([]string, error) {
// getPluginID reads the plugin.json file and returns the plugin ID // getPluginID reads the plugin.json file and returns the plugin ID
func (m *Manager) getPluginID(pluginPath string) string { func (m *Manager) getPluginID(pluginPath string) string {
manifest := m.getPluginManifest(pluginPath)
if manifest == nil {
return ""
}
return manifest.ID
}
func (m *Manager) getPluginManifest(pluginPath string) *pluginManifest {
manifestPath := filepath.Join(pluginPath, "plugin.json") manifestPath := filepath.Join(pluginPath, "plugin.json")
data, err := afero.ReadFile(m.fs, manifestPath) data, err := afero.ReadFile(m.fs, manifestPath)
if err != nil { if err != nil {
return nil return ""
} }
var manifest pluginManifest var manifest struct {
ID string `json:"id"`
}
if err := json.Unmarshal(data, &manifest); err != nil { if err := json.Unmarshal(data, &manifest); err != nil {
return nil return ""
} }
return &manifest return manifest.ID
}
type pluginManifest struct {
ID string `json:"id"`
Name string `json:"name"`
} }
func (m *Manager) GetPluginsDir() string { func (m *Manager) GetPluginsDir() string {
return m.pluginsDir return m.pluginsDir
} }
func (m *Manager) UninstallByIDOrName(idOrName string) error {
pluginPath, err := m.findInstalledPathByIDOrName(idOrName)
if err != nil {
return err
}
if pluginPath == "" {
return fmt.Errorf("plugin not found: %s", idOrName)
}
if strings.HasPrefix(pluginPath, "/etc/xdg/quickshell/dms-plugins") {
return fmt.Errorf("cannot uninstall system plugin: %s", idOrName)
}
metaPath := pluginPath + ".meta"
metaExists, _ := afero.Exists(m.fs, metaPath)
if metaExists {
if err := m.fs.Remove(pluginPath); err != nil {
return fmt.Errorf("failed to remove symlink: %w", err)
}
if err := m.fs.Remove(metaPath); err != nil {
return fmt.Errorf("failed to remove metadata: %w", err)
}
} else {
if err := m.fs.RemoveAll(pluginPath); err != nil {
return fmt.Errorf("failed to remove plugin: %w", err)
}
}
return nil
}
func (m *Manager) UpdateByIDOrName(idOrName string) error {
pluginPath, err := m.findInstalledPathByIDOrName(idOrName)
if err != nil {
return err
}
if pluginPath == "" {
return fmt.Errorf("plugin not found: %s", idOrName)
}
if strings.HasPrefix(pluginPath, "/etc/xdg/quickshell/dms-plugins") {
return fmt.Errorf("cannot update system plugin: %s", idOrName)
}
metaPath := pluginPath + ".meta"
metaExists, _ := afero.Exists(m.fs, metaPath)
if metaExists {
// Plugin is from monorepo, but we don't know the repo URL without registry
// Just try to pull from existing .git in the symlink target
return fmt.Errorf("cannot update monorepo plugin without registry info: %s", idOrName)
}
// Standalone plugin - just pull
if err := m.gitClient.Pull(pluginPath); err != nil {
return fmt.Errorf("failed to update plugin: %w", err)
}
return nil
}
func (m *Manager) findInstalledPathByIDOrName(idOrName string) (string, error) {
path, err := m.findInDirByIDOrName(m.pluginsDir, idOrName)
if err != nil {
return "", err
}
if path != "" {
return path, nil
}
systemDir := "/etc/xdg/quickshell/dms-plugins"
return m.findInDirByIDOrName(systemDir, idOrName)
}
func (m *Manager) findInDirByIDOrName(dir, idOrName string) (string, error) {
// Check exact folder name match first
exactPath := filepath.Join(dir, idOrName)
if exists, _ := afero.DirExists(m.fs, exactPath); exists {
return exactPath, nil
}
exists, err := afero.DirExists(m.fs, dir)
if err != nil || !exists {
return "", nil
}
entries, err := afero.ReadDir(m.fs, dir)
if err != nil {
return "", nil
}
for _, entry := range entries {
name := entry.Name()
if name == ".repos" || strings.HasSuffix(name, ".meta") {
continue
}
fullPath := filepath.Join(dir, name)
isPlugin := entry.IsDir() || entry.Mode()&os.ModeSymlink != 0
if !isPlugin {
if info, err := m.fs.Stat(fullPath); err == nil && info.IsDir() {
isPlugin = true
}
}
if !isPlugin {
continue
}
manifest := m.getPluginManifest(fullPath)
if manifest == nil {
continue
}
if manifest.ID == idOrName || manifest.Name == idOrName {
return fullPath, nil
}
}
return "", nil
}
func (m *Manager) HasUpdates(pluginID string, plugin Plugin) (bool, error) { func (m *Manager) HasUpdates(pluginID string, plugin Plugin) (bool, error) {
pluginPath, err := m.findInstalledPath(pluginID) pluginPath := filepath.Join(m.pluginsDir, pluginID)
exists, err := afero.DirExists(m.fs, pluginPath)
if err != nil { if err != nil {
return false, fmt.Errorf("failed to find plugin: %w", err) return false, fmt.Errorf("failed to check if plugin exists: %w", err)
} }
if pluginPath == "" { if !exists {
systemPluginPath := filepath.Join("/etc/xdg/quickshell/dms-plugins", pluginID)
systemExists, err := afero.DirExists(m.fs, systemPluginPath)
if err != nil {
return false, fmt.Errorf("failed to check system plugin: %w", err)
}
if systemExists {
return false, nil
}
return false, fmt.Errorf("plugin not installed: %s", pluginID) return false, fmt.Errorf("plugin not installed: %s", pluginID)
} }
if strings.HasPrefix(pluginPath, "/etc/xdg/quickshell/dms-plugins") { // Check if there's a .meta file (plugin installed from a monorepo)
return false, nil
}
metaPath := pluginPath + ".meta" metaPath := pluginPath + ".meta"
metaExists, err := afero.Exists(m.fs, metaPath) metaExists, err := afero.Exists(m.fs, metaPath)
if err != nil { if err != nil {

View File

@@ -3,8 +3,6 @@ package plugins
import ( import (
"sort" "sort"
"strings" "strings"
"github.com/AvengeMedia/DankMaterialShell/core/internal/utils"
) )
func FuzzySearch(query string, plugins []Plugin) []Plugin { func FuzzySearch(query string, plugins []Plugin) []Plugin {
@@ -13,12 +11,18 @@ func FuzzySearch(query string, plugins []Plugin) []Plugin {
} }
queryLower := strings.ToLower(query) queryLower := strings.ToLower(query)
return utils.Filter(plugins, func(p Plugin) bool { var results []Plugin
return fuzzyMatch(queryLower, strings.ToLower(p.Name)) ||
fuzzyMatch(queryLower, strings.ToLower(p.Category)) || for _, plugin := range plugins {
fuzzyMatch(queryLower, strings.ToLower(p.Description)) || if fuzzyMatch(queryLower, strings.ToLower(plugin.Name)) ||
fuzzyMatch(queryLower, strings.ToLower(p.Author)) fuzzyMatch(queryLower, strings.ToLower(plugin.Category)) ||
}) fuzzyMatch(queryLower, strings.ToLower(plugin.Description)) ||
fuzzyMatch(queryLower, strings.ToLower(plugin.Author)) {
results = append(results, plugin)
}
}
return results
} }
func fuzzyMatch(query, text string) bool { func fuzzyMatch(query, text string) bool {
@@ -35,34 +39,57 @@ func FilterByCategory(category string, plugins []Plugin) []Plugin {
if category == "" { if category == "" {
return plugins return plugins
} }
var results []Plugin
categoryLower := strings.ToLower(category) categoryLower := strings.ToLower(category)
return utils.Filter(plugins, func(p Plugin) bool {
return strings.ToLower(p.Category) == categoryLower for _, plugin := range plugins {
}) if strings.ToLower(plugin.Category) == categoryLower {
results = append(results, plugin)
}
}
return results
} }
func FilterByCompositor(compositor string, plugins []Plugin) []Plugin { func FilterByCompositor(compositor string, plugins []Plugin) []Plugin {
if compositor == "" { if compositor == "" {
return plugins return plugins
} }
var results []Plugin
compositorLower := strings.ToLower(compositor) compositorLower := strings.ToLower(compositor)
return utils.Filter(plugins, func(p Plugin) bool {
return utils.Any(p.Compositors, func(c string) bool { for _, plugin := range plugins {
return strings.ToLower(c) == compositorLower for _, comp := range plugin.Compositors {
}) if strings.ToLower(comp) == compositorLower {
}) results = append(results, plugin)
break
}
}
}
return results
} }
func FilterByCapability(capability string, plugins []Plugin) []Plugin { func FilterByCapability(capability string, plugins []Plugin) []Plugin {
if capability == "" { if capability == "" {
return plugins return plugins
} }
var results []Plugin
capabilityLower := strings.ToLower(capability) capabilityLower := strings.ToLower(capability)
return utils.Filter(plugins, func(p Plugin) bool {
return utils.Any(p.Capabilities, func(c string) bool { for _, plugin := range plugins {
return strings.ToLower(c) == capabilityLower for _, cap := range plugin.Capabilities {
}) if strings.ToLower(cap) == capabilityLower {
}) results = append(results, plugin)
break
}
}
}
return results
} }
func SortByFirstParty(plugins []Plugin) []Plugin { func SortByFirstParty(plugins []Plugin) []Plugin {
@@ -76,13 +103,3 @@ func SortByFirstParty(plugins []Plugin) []Plugin {
}) })
return plugins return plugins
} }
func FindByIDOrName(idOrName string, plugins []Plugin) *Plugin {
if p, found := utils.Find(plugins, func(p Plugin) bool { return p.ID == idOrName }); found {
return &p
}
if p, found := utils.Find(plugins, func(p Plugin) bool { return p.Name == idOrName }); found {
return &p
}
return nil
}

View File

@@ -238,17 +238,9 @@ func (i *ZwlrOutputManagerV1) Dispatch(opcode uint32, fd int, data []byte) {
l := 0 l := 0
objectID := client.Uint32(data[l : l+4]) objectID := client.Uint32(data[l : l+4])
proxy := i.Context().GetProxy(objectID) proxy := i.Context().GetProxy(objectID)
if proxy == nil { if proxy != nil {
head := &ZwlrOutputHeadV1{} e.Head = proxy.(*ZwlrOutputHeadV1)
head.SetContext(i.Context())
head.SetID(objectID)
registerServerProxy(i.Context(), head, objectID)
e.Head = head
} else if head, ok := proxy.(*ZwlrOutputHeadV1); ok {
e.Head = head
} else { } else {
// Stale proxy of wrong type (can happen after suspend/resume)
// Replace it with the correct type
head := &ZwlrOutputHeadV1{} head := &ZwlrOutputHeadV1{}
head.SetContext(i.Context()) head.SetContext(i.Context())
head.SetID(objectID) head.SetID(objectID)
@@ -723,17 +715,9 @@ func (i *ZwlrOutputHeadV1) Dispatch(opcode uint32, fd int, data []byte) {
l := 0 l := 0
objectID := client.Uint32(data[l : l+4]) objectID := client.Uint32(data[l : l+4])
proxy := i.Context().GetProxy(objectID) proxy := i.Context().GetProxy(objectID)
if proxy == nil { if proxy != nil {
mode := &ZwlrOutputModeV1{} e.Mode = proxy.(*ZwlrOutputModeV1)
mode.SetContext(i.Context())
mode.SetID(objectID)
registerServerProxy(i.Context(), mode, objectID)
e.Mode = mode
} else if mode, ok := proxy.(*ZwlrOutputModeV1); ok {
e.Mode = mode
} else { } else {
// Stale proxy of wrong type (can happen after suspend/resume)
// Replace it with the correct type
mode := &ZwlrOutputModeV1{} mode := &ZwlrOutputModeV1{}
mode.SetContext(i.Context()) mode.SetContext(i.Context())
mode.SetID(objectID) mode.SetID(objectID)
@@ -759,26 +743,7 @@ func (i *ZwlrOutputHeadV1) Dispatch(opcode uint32, fd int, data []byte) {
} }
var e ZwlrOutputHeadV1CurrentModeEvent var e ZwlrOutputHeadV1CurrentModeEvent
l := 0 l := 0
objectID := client.Uint32(data[l : l+4]) e.Mode = i.Context().GetProxy(client.Uint32(data[l : l+4])).(*ZwlrOutputModeV1)
proxy := i.Context().GetProxy(objectID)
if proxy == nil {
// Mode not yet registered, create it
mode := &ZwlrOutputModeV1{}
mode.SetContext(i.Context())
mode.SetID(objectID)
registerServerProxy(i.Context(), mode, objectID)
e.Mode = mode
} else if mode, ok := proxy.(*ZwlrOutputModeV1); ok {
e.Mode = mode
} else {
// Stale proxy of wrong type (can happen after suspend/resume)
// Replace it with the correct type
mode := &ZwlrOutputModeV1{}
mode.SetContext(i.Context())
mode.SetID(objectID)
registerServerProxy(i.Context(), mode, objectID)
e.Mode = mode
}
l += 4 l += 4
i.currentModeHandler(e) i.currentModeHandler(e)

View File

@@ -7,7 +7,6 @@ import (
"os/exec" "os/exec"
"github.com/AvengeMedia/DankMaterialShell/core/internal/proto/dwl_ipc" "github.com/AvengeMedia/DankMaterialShell/core/internal/proto/dwl_ipc"
"github.com/AvengeMedia/DankMaterialShell/core/internal/proto/wlr_output_management"
wlhelpers "github.com/AvengeMedia/DankMaterialShell/core/internal/wayland/client" wlhelpers "github.com/AvengeMedia/DankMaterialShell/core/internal/wayland/client"
"github.com/AvengeMedia/DankMaterialShell/core/pkg/go-wayland/wayland/client" "github.com/AvengeMedia/DankMaterialShell/core/pkg/go-wayland/wayland/client"
) )
@@ -20,7 +19,6 @@ const (
CompositorSway CompositorSway
CompositorNiri CompositorNiri
CompositorDWL CompositorDWL
CompositorScroll
) )
var detectedCompositor Compositor = -1 var detectedCompositor Compositor = -1
@@ -33,7 +31,6 @@ func DetectCompositor() Compositor {
hyprlandSig := os.Getenv("HYPRLAND_INSTANCE_SIGNATURE") hyprlandSig := os.Getenv("HYPRLAND_INSTANCE_SIGNATURE")
niriSocket := os.Getenv("NIRI_SOCKET") niriSocket := os.Getenv("NIRI_SOCKET")
swaySocket := os.Getenv("SWAYSOCK") swaySocket := os.Getenv("SWAYSOCK")
scrollSocket := os.Getenv("SCROLLSOCK")
switch { switch {
case niriSocket != "": case niriSocket != "":
@@ -41,12 +38,6 @@ func DetectCompositor() Compositor {
detectedCompositor = CompositorNiri detectedCompositor = CompositorNiri
return detectedCompositor return detectedCompositor
} }
case scrollSocket != "":
if _, err := os.Stat(scrollSocket); err == nil {
detectedCompositor = CompositorScroll
return detectedCompositor
}
case swaySocket != "": case swaySocket != "":
if _, err := os.Stat(swaySocket); err == nil { if _, err := os.Stat(swaySocket); err == nil {
detectedCompositor = CompositorSway detectedCompositor = CompositorSway
@@ -98,15 +89,12 @@ func SetCompositorDWL() {
} }
type WindowGeometry struct { type WindowGeometry struct {
X int32 X int32
Y int32 Y int32
Width int32 Width int32
Height int32 Height int32
Output string Output string
Scale float64 Scale float64
OutputX int32
OutputY int32
OutputTransform int32
} }
func GetActiveWindow() (*WindowGeometry, error) { func GetActiveWindow() (*WindowGeometry, error) {
@@ -241,25 +229,6 @@ func getSwayFocusedMonitor() string {
return "" return ""
} }
func getScrollFocusedMonitor() string {
output, err := exec.Command("scrollmsg", "-t", "get_workspaces").Output()
if err != nil {
return ""
}
var workspaces []swayWorkspace
if err := json.Unmarshal(output, &workspaces); err != nil {
return ""
}
for _, ws := range workspaces {
if ws.Focused {
return ws.Output
}
}
return ""
}
type niriWorkspace struct { type niriWorkspace struct {
Output string `json:"output"` Output string `json:"output"`
IsFocused bool `json:"is_focused"` IsFocused bool `json:"is_focused"`
@@ -405,8 +374,6 @@ func GetFocusedMonitor() string {
return getHyprlandFocusedMonitor() return getHyprlandFocusedMonitor()
case CompositorSway: case CompositorSway:
return getSwayFocusedMonitor() return getSwayFocusedMonitor()
case CompositorScroll:
return getScrollFocusedMonitor()
case CompositorNiri: case CompositorNiri:
return getNiriFocusedMonitor() return getNiriFocusedMonitor()
case CompositorDWL: case CompositorDWL:
@@ -415,92 +382,6 @@ func GetFocusedMonitor() string {
return "" return ""
} }
type outputInfo struct {
x, y int32
transform int32
}
func getOutputInfo(outputName string) (*outputInfo, bool) {
display, err := client.Connect("")
if err != nil {
return nil, false
}
ctx := display.Context()
defer ctx.Close()
registry, err := display.GetRegistry()
if err != nil {
return nil, false
}
var outputManager *wlr_output_management.ZwlrOutputManagerV1
registry.SetGlobalHandler(func(e client.RegistryGlobalEvent) {
if e.Interface == wlr_output_management.ZwlrOutputManagerV1InterfaceName {
mgr := wlr_output_management.NewZwlrOutputManagerV1(ctx)
version := e.Version
if version > 4 {
version = 4
}
if err := registry.Bind(e.Name, e.Interface, version, mgr); err == nil {
outputManager = mgr
}
}
})
if err := wlhelpers.Roundtrip(display, ctx); err != nil {
return nil, false
}
if outputManager == nil {
return nil, false
}
type headState struct {
name string
x, y int32
transform int32
}
heads := make(map[*wlr_output_management.ZwlrOutputHeadV1]*headState)
done := false
outputManager.SetHeadHandler(func(e wlr_output_management.ZwlrOutputManagerV1HeadEvent) {
state := &headState{}
heads[e.Head] = state
e.Head.SetNameHandler(func(ne wlr_output_management.ZwlrOutputHeadV1NameEvent) {
state.name = ne.Name
})
e.Head.SetPositionHandler(func(pe wlr_output_management.ZwlrOutputHeadV1PositionEvent) {
state.x = pe.X
state.y = pe.Y
})
e.Head.SetTransformHandler(func(te wlr_output_management.ZwlrOutputHeadV1TransformEvent) {
state.transform = te.Transform
})
})
outputManager.SetDoneHandler(func(e wlr_output_management.ZwlrOutputManagerV1DoneEvent) {
done = true
})
for !done {
if err := ctx.Dispatch(); err != nil {
return nil, false
}
}
for _, state := range heads {
if state.name == outputName {
return &outputInfo{
x: state.x,
y: state.y,
transform: state.transform,
}, true
}
}
return nil, false
}
func getDWLActiveWindow() (*WindowGeometry, error) { func getDWLActiveWindow() (*WindowGeometry, error) {
display, err := client.Connect("") display, err := client.Connect("")
if err != nil { if err != nil {
@@ -628,23 +509,14 @@ func getDWLActiveWindow() (*WindowGeometry, error) {
if scale <= 0 { if scale <= 0 {
scale = 1.0 scale = 1.0
} }
return &WindowGeometry{
geom := &WindowGeometry{
X: state.x, X: state.x,
Y: state.y, Y: state.y,
Width: state.w, Width: state.w,
Height: state.h, Height: state.h,
Output: state.name, Output: state.name,
Scale: scale, Scale: scale,
} }, nil
if info, ok := getOutputInfo(state.name); ok {
geom.OutputX = info.x
geom.OutputY = info.y
geom.OutputTransform = info.transform
}
return geom, nil
} }
return nil, fmt.Errorf("no active output found") return nil, fmt.Errorf("no active output found")

View File

@@ -9,10 +9,7 @@ import (
"io" "io"
"os" "os"
"path/filepath" "path/filepath"
"strings"
"time" "time"
"github.com/AvengeMedia/DankMaterialShell/core/internal/utils"
) )
func BufferToImage(buf *ShmBuffer) *image.RGBA { func BufferToImage(buf *ShmBuffer) *image.RGBA {
@@ -23,13 +20,7 @@ func BufferToImageWithFormat(buf *ShmBuffer, format uint32) *image.RGBA {
img := image.NewRGBA(image.Rect(0, 0, buf.Width, buf.Height)) img := image.NewRGBA(image.Rect(0, 0, buf.Width, buf.Height))
data := buf.Data() data := buf.Data()
var swapRB bool swapRB := format == uint32(FormatARGB8888) || format == uint32(FormatXRGB8888) || format == 0
switch format {
case uint32(FormatABGR8888), uint32(FormatXBGR8888):
swapRB = false
default:
swapRB = true
}
for y := 0; y < buf.Height; y++ { for y := 0; y < buf.Height; y++ {
srcOff := y * buf.Stride srcOff := y * buf.Stride
@@ -119,30 +110,70 @@ func GetOutputDir() string {
} }
func getXDGPicturesDir() string { func getXDGPicturesDir() string {
userDirsFile := filepath.Join(utils.XDGConfigHome(), "user-dirs.dirs") configDir := os.Getenv("XDG_CONFIG_HOME")
if configDir == "" {
home := os.Getenv("HOME")
if home == "" {
return ""
}
configDir = filepath.Join(home, ".config")
}
userDirsFile := filepath.Join(configDir, "user-dirs.dirs")
data, err := os.ReadFile(userDirsFile) data, err := os.ReadFile(userDirsFile)
if err != nil { if err != nil {
return "" return ""
} }
for _, line := range strings.Split(string(data), "\n") { for _, line := range splitLines(string(data)) {
if len(line) == 0 || line[0] == '#' { if len(line) == 0 || line[0] == '#' {
continue continue
} }
const prefix = "XDG_PICTURES_DIR=" const prefix = "XDG_PICTURES_DIR="
if !strings.HasPrefix(line, prefix) { if len(line) > len(prefix) && line[:len(prefix)] == prefix {
continue path := line[len(prefix):]
path = trimQuotes(path)
path = expandHome(path)
return path
} }
path := strings.Trim(line[len(prefix):], "\"")
expanded, err := utils.ExpandPath(path)
if err != nil {
return ""
}
return expanded
} }
return "" return ""
} }
func splitLines(s string) []string {
var lines []string
start := 0
for i := 0; i < len(s); i++ {
if s[i] == '\n' {
lines = append(lines, s[start:i])
start = i + 1
}
}
if start < len(s) {
lines = append(lines, s[start:])
}
return lines
}
func trimQuotes(s string) string {
if len(s) >= 2 && s[0] == '"' && s[len(s)-1] == '"' {
return s[1 : len(s)-1]
}
return s
}
func expandHome(path string) string {
if len(path) >= 5 && path[:5] == "$HOME" {
home := os.Getenv("HOME")
return home + path[5:]
}
if len(path) >= 1 && path[0] == '~' {
home := os.Getenv("HOME")
return home + path[1:]
}
return path
}
func WriteToFile(buf *ShmBuffer, path string, format Format, quality int) error { func WriteToFile(buf *ShmBuffer, path string, format Format, quality int) error {
return WriteToFileWithFormat(buf, path, format, quality, uint32(FormatARGB8888)) return WriteToFileWithFormat(buf, path, format, quality, uint32(FormatARGB8888))
} }

View File

@@ -380,24 +380,19 @@ func (r *RegionSelector) preCaptureOutput(output *WaylandOutput, pc *PreCapture,
return return
} }
var capturedBuf *ShmBuffer
var capturedFormat PixelFormat
frame.SetBufferHandler(func(e wlr_screencopy.ZwlrScreencopyFrameV1BufferEvent) { frame.SetBufferHandler(func(e wlr_screencopy.ZwlrScreencopyFrameV1BufferEvent) {
capturedFormat = PixelFormat(e.Format)
bpp := capturedFormat.BytesPerPixel()
if int(e.Stride) < int(e.Width)*bpp {
log.Error("invalid stride from compositor", "stride", e.Stride, "width", e.Width, "bpp", bpp)
return
}
buf, err := CreateShmBuffer(int(e.Width), int(e.Height), int(e.Stride)) buf, err := CreateShmBuffer(int(e.Width), int(e.Height), int(e.Stride))
if err != nil { if err != nil {
log.Error("create screen buffer failed", "err", err) log.Error("create screen buffer failed", "err", err)
return return
} }
capturedBuf = buf if withCursor {
buf.Format = capturedFormat pc.screenBuf = buf
pc.format = e.Format
} else {
pc.screenBufNoCursor = buf
}
pool, err := r.shm.CreatePool(buf.Fd(), int32(buf.Size())) pool, err := r.shm.CreatePool(buf.Fd(), int32(buf.Size()))
if err != nil { if err != nil {
@@ -426,47 +421,6 @@ func (r *RegionSelector) preCaptureOutput(output *WaylandOutput, pc *PreCapture,
frame.SetReadyHandler(func(e wlr_screencopy.ZwlrScreencopyFrameV1ReadyEvent) { frame.SetReadyHandler(func(e wlr_screencopy.ZwlrScreencopyFrameV1ReadyEvent) {
frame.Destroy() frame.Destroy()
if capturedBuf == nil {
onReady()
return
}
if capturedFormat.Is24Bit() {
converted, newFormat, err := capturedBuf.ConvertTo32Bit(capturedFormat)
if err != nil {
log.Error("convert 24-bit to 32-bit failed", "err", err)
} else if converted != capturedBuf {
capturedBuf.Close()
capturedBuf = converted
capturedFormat = newFormat
}
}
pc.format = uint32(capturedFormat)
if pc.yInverted {
capturedBuf.FlipVertical()
pc.yInverted = false
}
if output.transform != TransformNormal {
invTransform := InverseTransform(output.transform)
transformed, err := capturedBuf.ApplyTransform(invTransform)
if err != nil {
log.Error("apply transform failed", "err", err)
} else if transformed != capturedBuf {
capturedBuf.Close()
capturedBuf = transformed
}
}
if withCursor {
pc.screenBuf = capturedBuf
} else {
pc.screenBufNoCursor = capturedBuf
}
onReady() onReady()
}) })

View File

@@ -150,33 +150,51 @@ func (s *Screenshoter) captureWindow() (*CaptureResult, error) {
case CompositorHyprland: case CompositorHyprland:
return s.captureAndCrop(output, region) return s.captureAndCrop(output, region)
case CompositorDWL: case CompositorDWL:
return s.captureDWLWindow(output, region, geom) return s.captureDWLWindow(output, region, geom.Scale)
default: default:
return s.captureRegionOnOutput(output, region) return s.captureRegionOnOutput(output, region)
} }
} }
func (s *Screenshoter) captureDWLWindow(output *WaylandOutput, region Region, geom *WindowGeometry) (*CaptureResult, error) { func (s *Screenshoter) captureDWLWindow(output *WaylandOutput, region Region, dwlScale float64) (*CaptureResult, error) {
result, err := s.captureWholeOutput(output) result, err := s.captureWholeOutput(output)
if err != nil { if err != nil {
return nil, err return nil, err
} }
scale := geom.Scale scale := dwlScale
if scale <= 0 || scale == 1.0 { if scale <= 0 {
if output.fractionalScale > 1.0 { scale = float64(result.Buffer.Width) / float64(output.width)
scale = output.fractionalScale
}
} }
if scale <= 0 { if scale <= 0 {
scale = 1.0 scale = 1.0
} }
localX := int(float64(region.X-geom.OutputX) * scale) localX := int(float64(region.X) * scale)
localY := int(float64(region.Y-geom.OutputY) * scale) localY := int(float64(region.Y) * scale)
if localX >= result.Buffer.Width {
localX = localX % result.Buffer.Width
}
if localY >= result.Buffer.Height {
localY = localY % result.Buffer.Height
}
w := int(float64(region.Width) * scale) w := int(float64(region.Width) * scale)
h := int(float64(region.Height) * scale) h := int(float64(region.Height) * scale)
if localY+h > result.Buffer.Height && h <= result.Buffer.Height {
localY = result.Buffer.Height - h
if localY < 0 {
localY = 0
}
}
if localX+w > result.Buffer.Width && w <= result.Buffer.Width {
localX = result.Buffer.Width - w
if localX < 0 {
localX = 0
}
}
if localX < 0 { if localX < 0 {
w += localX w += localX
localX = 0 localX = 0
@@ -324,18 +342,13 @@ func (s *Screenshoter) captureAllScreens() (*CaptureResult, error) {
outX, outY := output.x, output.y outX, outY := output.x, output.y
scale := float64(output.scale) scale := float64(output.scale)
switch DetectCompositor() { if DetectCompositor() == CompositorHyprland {
case CompositorHyprland:
if hx, hy, _, _, ok := GetHyprlandMonitorGeometry(output.name); ok { if hx, hy, _, _, ok := GetHyprlandMonitorGeometry(output.name); ok {
outX, outY = hx, hy outX, outY = hx, hy
} }
if s := GetHyprlandMonitorScale(output.name); s > 0 { if s := GetHyprlandMonitorScale(output.name); s > 0 {
scale = s scale = s
} }
case CompositorDWL:
if info, ok := getOutputInfo(output.name); ok {
outX, outY = info.x, info.y
}
} }
if scale <= 0 { if scale <= 0 {
scale = 1.0 scale = 1.0
@@ -463,42 +476,13 @@ func (s *Screenshoter) captureWholeOutput(output *WaylandOutput) (*CaptureResult
return nil, fmt.Errorf("capture output: %w", err) return nil, fmt.Errorf("capture output: %w", err)
} }
result, err := s.processFrame(frame, Region{ return s.processFrame(frame, Region{
X: output.x, X: output.x,
Y: output.y, Y: output.y,
Width: output.width, Width: output.width,
Height: output.height, Height: output.height,
Output: output.name, Output: output.name,
}) })
if err != nil {
return nil, err
}
if result.YInverted {
result.Buffer.FlipVertical()
result.YInverted = false
}
if output.transform == TransformNormal {
return result, nil
}
invTransform := InverseTransform(output.transform)
transformed, err := result.Buffer.ApplyTransform(invTransform)
if err != nil {
result.Buffer.Close()
return nil, fmt.Errorf("apply transform: %w", err)
}
if transformed != result.Buffer {
result.Buffer.Close()
result.Buffer = transformed
}
result.Region.Width = int32(transformed.Width)
result.Region.Height = int32(transformed.Height)
return result, nil
} }
func (s *Screenshoter) captureAndCrop(output *WaylandOutput, region Region) (*CaptureResult, error) { func (s *Screenshoter) captureAndCrop(output *WaylandOutput, region Region) (*CaptureResult, error) {
@@ -579,10 +563,6 @@ func (s *Screenshoter) captureAndCrop(output *WaylandOutput, region Region) (*Ca
} }
func (s *Screenshoter) captureRegionOnOutput(output *WaylandOutput, region Region) (*CaptureResult, error) { func (s *Screenshoter) captureRegionOnOutput(output *WaylandOutput, region Region) (*CaptureResult, error) {
if output.transform != TransformNormal {
return s.captureRegionOnTransformedOutput(output, region)
}
scale := output.fractionalScale scale := output.fractionalScale
if scale <= 0 && DetectCompositor() == CompositorHyprland { if scale <= 0 && DetectCompositor() == CompositorHyprland {
scale = GetHyprlandMonitorScale(output.name) scale = GetHyprlandMonitorScale(output.name)
@@ -637,76 +617,6 @@ func (s *Screenshoter) captureRegionOnOutput(output *WaylandOutput, region Regio
return s.processFrame(frame, region) return s.processFrame(frame, region)
} }
func (s *Screenshoter) captureRegionOnTransformedOutput(output *WaylandOutput, region Region) (*CaptureResult, error) {
result, err := s.captureWholeOutput(output)
if err != nil {
return nil, err
}
scale := output.fractionalScale
if scale <= 0 && DetectCompositor() == CompositorHyprland {
scale = GetHyprlandMonitorScale(output.name)
}
if scale <= 0 {
scale = float64(output.scale)
}
if scale <= 0 {
scale = 1.0
}
localX := int(float64(region.X-output.x) * scale)
localY := int(float64(region.Y-output.y) * scale)
w := int(float64(region.Width) * scale)
h := int(float64(region.Height) * scale)
if localX < 0 {
w += localX
localX = 0
}
if localY < 0 {
h += localY
localY = 0
}
if localX+w > result.Buffer.Width {
w = result.Buffer.Width - localX
}
if localY+h > result.Buffer.Height {
h = result.Buffer.Height - localY
}
if w <= 0 || h <= 0 {
result.Buffer.Close()
return nil, fmt.Errorf("region not visible on output")
}
cropped, err := CreateShmBuffer(w, h, w*4)
if err != nil {
result.Buffer.Close()
return nil, fmt.Errorf("create crop buffer: %w", err)
}
srcData := result.Buffer.Data()
dstData := cropped.Data()
for y := 0; y < h; y++ {
srcOff := (localY+y)*result.Buffer.Stride + localX*4
dstOff := y * cropped.Stride
if srcOff+w*4 <= len(srcData) && dstOff+w*4 <= len(dstData) {
copy(dstData[dstOff:dstOff+w*4], srcData[srcOff:srcOff+w*4])
}
}
result.Buffer.Close()
cropped.Format = PixelFormat(result.Format)
return &CaptureResult{
Buffer: cropped,
Region: region,
YInverted: false,
Format: result.Format,
}, nil
}
func (s *Screenshoter) processFrame(frame *wlr_screencopy.ZwlrScreencopyFrameV1, region Region) (*CaptureResult, error) { func (s *Screenshoter) processFrame(frame *wlr_screencopy.ZwlrScreencopyFrameV1, region Region) (*CaptureResult, error) {
var buf *ShmBuffer var buf *ShmBuffer
var pool *client.ShmPool var pool *client.ShmPool
@@ -717,18 +627,13 @@ func (s *Screenshoter) processFrame(frame *wlr_screencopy.ZwlrScreencopyFrameV1,
failed := false failed := false
frame.SetBufferHandler(func(e wlr_screencopy.ZwlrScreencopyFrameV1BufferEvent) { frame.SetBufferHandler(func(e wlr_screencopy.ZwlrScreencopyFrameV1BufferEvent) {
format = PixelFormat(e.Format)
bpp := format.BytesPerPixel()
if int(e.Stride) < int(e.Width)*bpp {
log.Error("invalid stride from compositor", "stride", e.Stride, "width", e.Width, "bpp", bpp)
return
}
var err error var err error
buf, err = CreateShmBuffer(int(e.Width), int(e.Height), int(e.Stride)) buf, err = CreateShmBuffer(int(e.Width), int(e.Height), int(e.Stride))
if err != nil { if err != nil {
log.Error("failed to create buffer", "err", err) log.Error("failed to create buffer", "err", err)
return return
} }
format = PixelFormat(e.Format)
buf.Format = format buf.Format = format
}) })
@@ -791,19 +696,6 @@ func (s *Screenshoter) processFrame(frame *wlr_screencopy.ZwlrScreencopyFrameV1,
return nil, fmt.Errorf("frame capture failed") return nil, fmt.Errorf("frame capture failed")
} }
if format.Is24Bit() {
converted, newFormat, err := buf.ConvertTo32Bit(format)
if err != nil {
buf.Close()
return nil, fmt.Errorf("convert 24-bit to 32-bit: %w", err)
}
if converted != buf {
buf.Close()
buf = converted
}
format = newFormat
}
return &CaptureResult{ return &CaptureResult{
Buffer: buf, Buffer: buf,
Region: region, Region: region,
@@ -1032,32 +924,16 @@ func ListOutputs() ([]Output, error) {
sc.outputsMu.Lock() sc.outputsMu.Lock()
defer sc.outputsMu.Unlock() defer sc.outputsMu.Unlock()
compositor := DetectCompositor()
result := make([]Output, 0, len(sc.outputs)) result := make([]Output, 0, len(sc.outputs))
for _, o := range sc.outputs { for _, o := range sc.outputs {
out := Output{ result = append(result, Output{
Name: o.name, Name: o.name,
X: o.x, X: o.x,
Y: o.y, Y: o.y,
Width: o.width, Width: o.width,
Height: o.height, Height: o.height,
Scale: o.scale, Scale: o.scale,
FractionalScale: o.fractionalScale, })
Transform: o.transform,
}
switch compositor {
case CompositorHyprland:
if hx, hy, hw, hh, ok := GetHyprlandMonitorGeometry(o.name); ok {
out.X, out.Y = hx, hy
out.Width, out.Height = hw, hh
}
if s := GetHyprlandMonitorScale(o.name); s > 0 {
out.FractionalScale = s
}
}
result = append(result, out)
} }
return result, nil return result, nil
} }

View File

@@ -9,19 +9,6 @@ const (
FormatXRGB8888 = shm.FormatXRGB8888 FormatXRGB8888 = shm.FormatXRGB8888
FormatABGR8888 = shm.FormatABGR8888 FormatABGR8888 = shm.FormatABGR8888
FormatXBGR8888 = shm.FormatXBGR8888 FormatXBGR8888 = shm.FormatXBGR8888
FormatRGB888 = shm.FormatRGB888
FormatBGR888 = shm.FormatBGR888
)
const (
TransformNormal = shm.TransformNormal
Transform90 = shm.Transform90
Transform180 = shm.Transform180
Transform270 = shm.Transform270
TransformFlipped = shm.TransformFlipped
TransformFlipped90 = shm.TransformFlipped90
TransformFlipped180 = shm.TransformFlipped180
TransformFlipped270 = shm.TransformFlipped270
) )
type ShmBuffer = shm.Buffer type ShmBuffer = shm.Buffer
@@ -29,7 +16,3 @@ type ShmBuffer = shm.Buffer
func CreateShmBuffer(width, height, stride int) (*ShmBuffer, error) { func CreateShmBuffer(width, height, stride int) (*ShmBuffer, error) {
return shm.CreateBuffer(width, height, stride) return shm.CreateBuffer(width, height, stride)
} }
func InverseTransform(transform int32) int32 {
return shm.InverseTransform(transform)
}

View File

@@ -6,8 +6,6 @@ import (
"os/exec" "os/exec"
"path/filepath" "path/filepath"
"strings" "strings"
"github.com/AvengeMedia/DankMaterialShell/core/internal/utils"
) )
type ThemeColors struct { type ThemeColors struct {
@@ -74,7 +72,15 @@ func loadColorsFile() *ColorScheme {
} }
func getColorsFilePath() string { func getColorsFilePath() string {
return filepath.Join(utils.XDGCacheHome(), "DankMaterialShell", "dms-colors.json") cacheDir := os.Getenv("XDG_CACHE_HOME")
if cacheDir == "" {
home := os.Getenv("HOME")
if home == "" {
return ""
}
cacheDir = filepath.Join(home, ".cache")
}
return filepath.Join(cacheDir, "DankMaterialShell", "dms-colors.json")
} }
func isLightMode() bool { func isLightMode() bool {

View File

@@ -32,13 +32,11 @@ func (r Region) IsEmpty() bool {
} }
type Output struct { type Output struct {
Name string Name string
X, Y int32 X, Y int32
Width int32 Width int32
Height int32 Height int32
Scale int32 Scale int32
FractionalScale float64
Transform int32
} }
type Config struct { type Config struct {

View File

@@ -7,7 +7,13 @@ import (
"github.com/AvengeMedia/DankMaterialShell/core/internal/server/models" "github.com/AvengeMedia/DankMaterialShell/core/internal/server/models"
) )
func HandleRequest(conn net.Conn, req models.Request, manager *Manager) { type Request struct {
ID int `json:"id"`
Method string `json:"method"`
Params map[string]any `json:"params"`
}
func HandleRequest(conn net.Conn, req Request, manager *Manager) {
switch req.Method { switch req.Method {
case "apppicker.open", "browser.open": case "apppicker.open", "browser.open":
handleOpen(conn, req, manager) handleOpen(conn, req, manager)
@@ -16,7 +22,7 @@ func HandleRequest(conn net.Conn, req models.Request, manager *Manager) {
} }
} }
func handleOpen(conn net.Conn, req models.Request, manager *Manager) { func handleOpen(conn net.Conn, req Request, manager *Manager) {
log.Infof("AppPicker: Received %s request with params: %+v", req.Method, req.Params) log.Infof("AppPicker: Received %s request with params: %+v", req.Method, req.Params)
target, ok := req.Params["target"].(string) target, ok := req.Params["target"].(string)

View File

@@ -6,15 +6,25 @@ import (
"net" "net"
"github.com/AvengeMedia/DankMaterialShell/core/internal/server/models" "github.com/AvengeMedia/DankMaterialShell/core/internal/server/models"
"github.com/AvengeMedia/DankMaterialShell/core/internal/server/params"
) )
type Request struct {
ID int `json:"id,omitempty"`
Method string `json:"method"`
Params map[string]any `json:"params,omitempty"`
}
type SuccessResult struct {
Success bool `json:"success"`
Message string `json:"message"`
}
type BluetoothEvent struct { type BluetoothEvent struct {
Type string `json:"type"` Type string `json:"type"`
Data BluetoothState `json:"data"` Data BluetoothState `json:"data"`
} }
func HandleRequest(conn net.Conn, req models.Request, manager *Manager) { func HandleRequest(conn net.Conn, req Request, manager *Manager) {
switch req.Method { switch req.Method {
case "bluetooth.getState": case "bluetooth.getState":
handleGetState(conn, req, manager) handleGetState(conn, req, manager)
@@ -47,30 +57,31 @@ func HandleRequest(conn net.Conn, req models.Request, manager *Manager) {
} }
} }
func handleGetState(conn net.Conn, req models.Request, manager *Manager) { func handleGetState(conn net.Conn, req Request, manager *Manager) {
models.Respond(conn, req.ID, manager.GetState()) state := manager.GetState()
models.Respond(conn, req.ID, state)
} }
func handleStartDiscovery(conn net.Conn, req models.Request, manager *Manager) { func handleStartDiscovery(conn net.Conn, req Request, manager *Manager) {
if err := manager.StartDiscovery(); err != nil { if err := manager.StartDiscovery(); err != nil {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, err.Error())
return return
} }
models.Respond(conn, req.ID, models.SuccessResult{Success: true, Message: "discovery started"}) models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "discovery started"})
} }
func handleStopDiscovery(conn net.Conn, req models.Request, manager *Manager) { func handleStopDiscovery(conn net.Conn, req Request, manager *Manager) {
if err := manager.StopDiscovery(); err != nil { if err := manager.StopDiscovery(); err != nil {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, err.Error())
return return
} }
models.Respond(conn, req.ID, models.SuccessResult{Success: true, Message: "discovery stopped"}) models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "discovery stopped"})
} }
func handleSetPowered(conn net.Conn, req models.Request, manager *Manager) { func handleSetPowered(conn net.Conn, req Request, manager *Manager) {
powered, err := params.Bool(req.Params, "powered") powered, ok := req.Params["powered"].(bool)
if err != nil { if !ok {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, "missing or invalid 'powered' parameter")
return return
} }
@@ -79,13 +90,13 @@ func handleSetPowered(conn net.Conn, req models.Request, manager *Manager) {
return return
} }
models.Respond(conn, req.ID, models.SuccessResult{Success: true, Message: "powered state updated"}) models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "powered state updated"})
} }
func handlePairDevice(conn net.Conn, req models.Request, manager *Manager) { func handlePairDevice(conn net.Conn, req Request, manager *Manager) {
devicePath, err := params.String(req.Params, "device") devicePath, ok := req.Params["device"].(string)
if err != nil { if !ok {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, "missing or invalid 'device' parameter")
return return
} }
@@ -94,13 +105,13 @@ func handlePairDevice(conn net.Conn, req models.Request, manager *Manager) {
return return
} }
models.Respond(conn, req.ID, models.SuccessResult{Success: true, Message: "pairing initiated"}) models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "pairing initiated"})
} }
func handleConnectDevice(conn net.Conn, req models.Request, manager *Manager) { func handleConnectDevice(conn net.Conn, req Request, manager *Manager) {
devicePath, err := params.String(req.Params, "device") devicePath, ok := req.Params["device"].(string)
if err != nil { if !ok {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, "missing or invalid 'device' parameter")
return return
} }
@@ -109,13 +120,13 @@ func handleConnectDevice(conn net.Conn, req models.Request, manager *Manager) {
return return
} }
models.Respond(conn, req.ID, models.SuccessResult{Success: true, Message: "connecting"}) models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "connecting"})
} }
func handleDisconnectDevice(conn net.Conn, req models.Request, manager *Manager) { func handleDisconnectDevice(conn net.Conn, req Request, manager *Manager) {
devicePath, err := params.String(req.Params, "device") devicePath, ok := req.Params["device"].(string)
if err != nil { if !ok {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, "missing or invalid 'device' parameter")
return return
} }
@@ -124,13 +135,13 @@ func handleDisconnectDevice(conn net.Conn, req models.Request, manager *Manager)
return return
} }
models.Respond(conn, req.ID, models.SuccessResult{Success: true, Message: "disconnected"}) models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "disconnected"})
} }
func handleRemoveDevice(conn net.Conn, req models.Request, manager *Manager) { func handleRemoveDevice(conn net.Conn, req Request, manager *Manager) {
devicePath, err := params.String(req.Params, "device") devicePath, ok := req.Params["device"].(string)
if err != nil { if !ok {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, "missing or invalid 'device' parameter")
return return
} }
@@ -139,13 +150,13 @@ func handleRemoveDevice(conn net.Conn, req models.Request, manager *Manager) {
return return
} }
models.Respond(conn, req.ID, models.SuccessResult{Success: true, Message: "device removed"}) models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "device removed"})
} }
func handleTrustDevice(conn net.Conn, req models.Request, manager *Manager) { func handleTrustDevice(conn net.Conn, req Request, manager *Manager) {
devicePath, err := params.String(req.Params, "device") devicePath, ok := req.Params["device"].(string)
if err != nil { if !ok {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, "missing or invalid 'device' parameter")
return return
} }
@@ -154,13 +165,13 @@ func handleTrustDevice(conn net.Conn, req models.Request, manager *Manager) {
return return
} }
models.Respond(conn, req.ID, models.SuccessResult{Success: true, Message: "device trusted"}) models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "device trusted"})
} }
func handleUntrustDevice(conn net.Conn, req models.Request, manager *Manager) { func handleUntrustDevice(conn net.Conn, req Request, manager *Manager) {
devicePath, err := params.String(req.Params, "device") devicePath, ok := req.Params["device"].(string)
if err != nil { if !ok {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, "missing or invalid 'device' parameter")
return return
} }
@@ -169,31 +180,43 @@ func handleUntrustDevice(conn net.Conn, req models.Request, manager *Manager) {
return return
} }
models.Respond(conn, req.ID, models.SuccessResult{Success: true, Message: "device untrusted"}) models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "device untrusted"})
} }
func handlePairingSubmit(conn net.Conn, req models.Request, manager *Manager) { func handlePairingSubmit(conn net.Conn, req Request, manager *Manager) {
token, err := params.String(req.Params, "token") token, ok := req.Params["token"].(string)
if err != nil { if !ok {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, "missing or invalid 'token' parameter")
return return
} }
secrets := params.StringMapOpt(req.Params, "secrets") secretsRaw, ok := req.Params["secrets"].(map[string]any)
accept := params.BoolOpt(req.Params, "accept", false) secrets := make(map[string]string)
if ok {
for k, v := range secretsRaw {
if str, ok := v.(string); ok {
secrets[k] = str
}
}
}
accept := false
if acceptParam, ok := req.Params["accept"].(bool); ok {
accept = acceptParam
}
if err := manager.SubmitPairing(token, secrets, accept); err != nil { if err := manager.SubmitPairing(token, secrets, accept); err != nil {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, err.Error())
return return
} }
models.Respond(conn, req.ID, models.SuccessResult{Success: true, Message: "pairing response submitted"}) models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "pairing response submitted"})
} }
func handlePairingCancel(conn net.Conn, req models.Request, manager *Manager) { func handlePairingCancel(conn net.Conn, req Request, manager *Manager) {
token, err := params.String(req.Params, "token") token, ok := req.Params["token"].(string)
if err != nil { if !ok {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, "missing or invalid 'token' parameter")
return return
} }
@@ -202,10 +225,10 @@ func handlePairingCancel(conn net.Conn, req models.Request, manager *Manager) {
return return
} }
models.Respond(conn, req.ID, models.SuccessResult{Success: true, Message: "pairing cancelled"}) models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "pairing cancelled"})
} }
func handleSubscribe(conn net.Conn, req models.Request, manager *Manager) { func handleSubscribe(conn net.Conn, req Request, manager *Manager) {
clientID := fmt.Sprintf("client-%p", conn) clientID := fmt.Sprintf("client-%p", conn)
stateChan := manager.Subscribe(clientID) stateChan := manager.Subscribe(clientID)
defer manager.Unsubscribe(clientID) defer manager.Unsubscribe(clientID)

View File

@@ -2,14 +2,12 @@ package brightness
import ( import (
"encoding/json" "encoding/json"
"fmt"
"net" "net"
"github.com/AvengeMedia/DankMaterialShell/core/internal/server/models" "github.com/AvengeMedia/DankMaterialShell/core/internal/server/models"
"github.com/AvengeMedia/DankMaterialShell/core/internal/server/params"
) )
func HandleRequest(conn net.Conn, req models.Request, m *Manager) { func HandleRequest(conn net.Conn, req Request, m *Manager) {
switch req.Method { switch req.Method {
case "brightness.getState": case "brightness.getState":
handleGetState(conn, req, m) handleGetState(conn, req, m)
@@ -24,90 +22,131 @@ func HandleRequest(conn net.Conn, req models.Request, m *Manager) {
case "brightness.subscribe": case "brightness.subscribe":
handleSubscribe(conn, req, m) handleSubscribe(conn, req, m)
default: default:
models.RespondError(conn, req.ID, "unknown method: "+req.Method) models.RespondError(conn, req.ID.(int), "unknown method: "+req.Method)
} }
} }
func handleGetState(conn net.Conn, req models.Request, m *Manager) { func handleGetState(conn net.Conn, req Request, m *Manager) {
models.Respond(conn, req.ID, m.GetState()) state := m.GetState()
models.Respond(conn, req.ID.(int), state)
} }
func handleSetBrightness(conn net.Conn, req models.Request, m *Manager) { func handleSetBrightness(conn net.Conn, req Request, m *Manager) {
device, err := params.String(req.Params, "device") var params SetBrightnessParams
if err != nil {
models.RespondError(conn, req.ID, err.Error()) device, ok := req.Params["device"].(string)
if !ok {
models.RespondError(conn, req.ID.(int), "missing or invalid device parameter")
return
}
params.Device = device
percentFloat, ok := req.Params["percent"].(float64)
if !ok {
models.RespondError(conn, req.ID.(int), "missing or invalid percent parameter")
return
}
params.Percent = int(percentFloat)
if exponential, ok := req.Params["exponential"].(bool); ok {
params.Exponential = exponential
}
exponent := 1.2
if exponentFloat, ok := req.Params["exponent"].(float64); ok {
params.Exponent = exponentFloat
exponent = exponentFloat
}
if err := m.SetBrightnessWithExponent(params.Device, params.Percent, params.Exponential, exponent); err != nil {
models.RespondError(conn, req.ID.(int), err.Error())
return return
} }
percent, err := params.Int(req.Params, "percent") state := m.GetState()
if err != nil { models.Respond(conn, req.ID.(int), state)
models.RespondError(conn, req.ID, err.Error())
return
}
exponential := params.BoolOpt(req.Params, "exponential", false)
exponent := params.FloatOpt(req.Params, "exponent", 1.2)
if err := m.SetBrightnessWithExponent(device, percent, exponential, exponent); err != nil {
models.RespondError(conn, req.ID, err.Error())
return
}
models.Respond(conn, req.ID, m.GetState())
} }
func handleIncrement(conn net.Conn, req models.Request, m *Manager) { func handleIncrement(conn net.Conn, req Request, m *Manager) {
device, err := params.String(req.Params, "device") device, ok := req.Params["device"].(string)
if err != nil { if !ok {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID.(int), "missing or invalid device parameter")
return return
} }
step := params.IntOpt(req.Params, "step", 10) step := 10
exponential := params.BoolOpt(req.Params, "exponential", false) if stepFloat, ok := req.Params["step"].(float64); ok {
exponent := params.FloatOpt(req.Params, "exponent", 1.2) step = int(stepFloat)
}
exponential := false
if expBool, ok := req.Params["exponential"].(bool); ok {
exponential = expBool
}
exponent := 1.2
if exponentFloat, ok := req.Params["exponent"].(float64); ok {
exponent = exponentFloat
}
if err := m.IncrementBrightnessWithExponent(device, step, exponential, exponent); err != nil { if err := m.IncrementBrightnessWithExponent(device, step, exponential, exponent); err != nil {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID.(int), err.Error())
return return
} }
models.Respond(conn, req.ID, m.GetState()) state := m.GetState()
models.Respond(conn, req.ID.(int), state)
} }
func handleDecrement(conn net.Conn, req models.Request, m *Manager) { func handleDecrement(conn net.Conn, req Request, m *Manager) {
device, err := params.String(req.Params, "device") device, ok := req.Params["device"].(string)
if err != nil { if !ok {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID.(int), "missing or invalid device parameter")
return return
} }
step := params.IntOpt(req.Params, "step", 10) step := 10
exponential := params.BoolOpt(req.Params, "exponential", false) if stepFloat, ok := req.Params["step"].(float64); ok {
exponent := params.FloatOpt(req.Params, "exponent", 1.2) step = int(stepFloat)
}
exponential := false
if expBool, ok := req.Params["exponential"].(bool); ok {
exponential = expBool
}
exponent := 1.2
if exponentFloat, ok := req.Params["exponent"].(float64); ok {
exponent = exponentFloat
}
if err := m.IncrementBrightnessWithExponent(device, -step, exponential, exponent); err != nil { if err := m.IncrementBrightnessWithExponent(device, -step, exponential, exponent); err != nil {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID.(int), err.Error())
return return
} }
models.Respond(conn, req.ID, m.GetState()) state := m.GetState()
models.Respond(conn, req.ID.(int), state)
} }
func handleRescan(conn net.Conn, req models.Request, m *Manager) { func handleRescan(conn net.Conn, req Request, m *Manager) {
m.Rescan() m.Rescan()
models.Respond(conn, req.ID, m.GetState()) state := m.GetState()
models.Respond(conn, req.ID.(int), state)
} }
func handleSubscribe(conn net.Conn, req models.Request, m *Manager) { func handleSubscribe(conn net.Conn, req Request, m *Manager) {
clientID := fmt.Sprintf("brightness-%d", req.ID) clientID := "brightness-subscriber"
if idStr, ok := req.ID.(string); ok && idStr != "" {
clientID = idStr
}
ch := m.Subscribe(clientID) ch := m.Subscribe(clientID)
defer m.Unsubscribe(clientID) defer m.Unsubscribe(clientID)
initialState := m.GetState() initialState := m.GetState()
if err := json.NewEncoder(conn).Encode(models.Response[State]{ if err := json.NewEncoder(conn).Encode(models.Response[State]{
ID: req.ID, ID: req.ID.(int),
Result: &initialState, Result: &initialState,
}); err != nil { }); err != nil {
return return
@@ -115,7 +154,7 @@ func handleSubscribe(conn net.Conn, req models.Request, m *Manager) {
for state := range ch { for state := range ch {
if err := json.NewEncoder(conn).Encode(models.Response[State]{ if err := json.NewEncoder(conn).Encode(models.Response[State]{
ID: req.ID, ID: req.ID.(int),
Result: &state, Result: &state,
}); err != nil { }); err != nil {
return return

View File

@@ -33,6 +33,12 @@ type DeviceUpdate struct {
Device Device `json:"device"` Device Device `json:"device"`
} }
type Request struct {
ID any `json:"id"`
Method string `json:"method"`
Params map[string]any `json:"params"`
}
type Manager struct { type Manager struct {
logindBackend *LogindBackend logindBackend *LogindBackend
sysfsBackend *SysfsBackend sysfsBackend *SysfsBackend
@@ -106,6 +112,13 @@ type ddcCapability struct {
current int current int
} }
type SetBrightnessParams struct {
Device string `json:"device"`
Percent int `json:"percent"`
Exponential bool `json:"exponential,omitempty"`
Exponent float64 `json:"exponent,omitempty"`
}
func (m *Manager) Subscribe(id string) chan State { func (m *Manager) Subscribe(id string) chan State {
ch := make(chan State, 16) ch := make(chan State, 16)

View File

@@ -6,7 +6,13 @@ import (
"github.com/AvengeMedia/DankMaterialShell/core/internal/server/models" "github.com/AvengeMedia/DankMaterialShell/core/internal/server/models"
) )
func HandleRequest(conn net.Conn, req models.Request, manager *Manager) { type Request struct {
ID int `json:"id"`
Method string `json:"method"`
Params map[string]any `json:"params"`
}
func HandleRequest(conn net.Conn, req Request, manager *Manager) {
switch req.Method { switch req.Method {
case "browser.open": case "browser.open":
url, ok := req.Params["url"].(string) url, ok := req.Params["url"].(string)

View File

@@ -6,21 +6,25 @@ import (
"net" "net"
"github.com/AvengeMedia/DankMaterialShell/core/internal/server/models" "github.com/AvengeMedia/DankMaterialShell/core/internal/server/models"
"github.com/AvengeMedia/DankMaterialShell/core/internal/server/params"
) )
type Request struct {
ID int `json:"id,omitempty"`
Method string `json:"method"`
Params map[string]any `json:"params,omitempty"`
}
type SuccessResult struct {
Success bool `json:"success"`
Message string `json:"message"`
}
type CUPSEvent struct { type CUPSEvent struct {
Type string `json:"type"` Type string `json:"type"`
Data CUPSState `json:"data"` Data CUPSState `json:"data"`
} }
type TestPageResult struct { func HandleRequest(conn net.Conn, req Request, manager *Manager) {
Success bool `json:"success"`
JobID int `json:"jobId"`
Message string `json:"message"`
}
func HandleRequest(conn net.Conn, req models.Request, manager *Manager) {
switch req.Method { switch req.Method {
case "cups.subscribe": case "cups.subscribe":
handleSubscribe(conn, req, manager) handleSubscribe(conn, req, manager)
@@ -75,19 +79,20 @@ func HandleRequest(conn net.Conn, req models.Request, manager *Manager) {
} }
} }
func handleGetPrinters(conn net.Conn, req models.Request, manager *Manager) { func handleGetPrinters(conn net.Conn, req Request, manager *Manager) {
printers, err := manager.GetPrinters() printers, err := manager.GetPrinters()
if err != nil { if err != nil {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, err.Error())
return return
} }
models.Respond(conn, req.ID, printers) models.Respond(conn, req.ID, printers)
} }
func handleGetJobs(conn net.Conn, req models.Request, manager *Manager) { func handleGetJobs(conn net.Conn, req Request, manager *Manager) {
printerName, err := params.String(req.Params, "printerName") printerName, ok := req.Params["printerName"].(string)
if err != nil { if !ok {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, "missing or invalid 'printerName' parameter")
return return
} }
@@ -96,13 +101,14 @@ func handleGetJobs(conn net.Conn, req models.Request, manager *Manager) {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, err.Error())
return return
} }
models.Respond(conn, req.ID, jobs) models.Respond(conn, req.ID, jobs)
} }
func handlePausePrinter(conn net.Conn, req models.Request, manager *Manager) { func handlePausePrinter(conn net.Conn, req Request, manager *Manager) {
printerName, err := params.String(req.Params, "printerName") printerName, ok := req.Params["printerName"].(string)
if err != nil { if !ok {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, "missing or invalid 'printerName' parameter")
return return
} }
@@ -110,13 +116,13 @@ func handlePausePrinter(conn net.Conn, req models.Request, manager *Manager) {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, err.Error())
return return
} }
models.Respond(conn, req.ID, models.SuccessResult{Success: true, Message: "paused"}) models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "paused"})
} }
func handleResumePrinter(conn net.Conn, req models.Request, manager *Manager) { func handleResumePrinter(conn net.Conn, req Request, manager *Manager) {
printerName, err := params.String(req.Params, "printerName") printerName, ok := req.Params["printerName"].(string)
if err != nil { if !ok {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, "missing or invalid 'printerName' parameter")
return return
} }
@@ -124,27 +130,28 @@ func handleResumePrinter(conn net.Conn, req models.Request, manager *Manager) {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, err.Error())
return return
} }
models.Respond(conn, req.ID, models.SuccessResult{Success: true, Message: "resumed"}) models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "resumed"})
} }
func handleCancelJob(conn net.Conn, req models.Request, manager *Manager) { func handleCancelJob(conn net.Conn, req Request, manager *Manager) {
jobID, err := params.Int(req.Params, "jobID") jobIDFloat, ok := req.Params["jobID"].(float64)
if err != nil { if !ok {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, "missing or invalid 'jobid' parameter")
return return
} }
jobID := int(jobIDFloat)
if err := manager.CancelJob(jobID); err != nil { if err := manager.CancelJob(jobID); err != nil {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, err.Error())
return return
} }
models.Respond(conn, req.ID, models.SuccessResult{Success: true, Message: "job canceled"}) models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "job canceled"})
} }
func handlePurgeJobs(conn net.Conn, req models.Request, manager *Manager) { func handlePurgeJobs(conn net.Conn, req Request, manager *Manager) {
printerName, err := params.String(req.Params, "printerName") printerName, ok := req.Params["printerName"].(string)
if err != nil { if !ok {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, "missing or invalid 'printerName' parameter")
return return
} }
@@ -152,10 +159,10 @@ func handlePurgeJobs(conn net.Conn, req models.Request, manager *Manager) {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, err.Error())
return return
} }
models.Respond(conn, req.ID, models.SuccessResult{Success: true, Message: "jobs canceled"}) models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "jobs canceled"})
} }
func handleSubscribe(conn net.Conn, req models.Request, manager *Manager) { func handleSubscribe(conn net.Conn, req Request, manager *Manager) {
clientID := fmt.Sprintf("client-%p", conn) clientID := fmt.Sprintf("client-%p", conn)
stateChan := manager.Subscribe(clientID) stateChan := manager.Subscribe(clientID)
defer manager.Unsubscribe(clientID) defer manager.Unsubscribe(clientID)
@@ -186,7 +193,7 @@ func handleSubscribe(conn net.Conn, req models.Request, manager *Manager) {
} }
} }
func handleGetDevices(conn net.Conn, req models.Request, manager *Manager) { func handleGetDevices(conn net.Conn, req Request, manager *Manager) {
devices, err := manager.GetDevices() devices, err := manager.GetDevices()
if err != nil { if err != nil {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, err.Error())
@@ -195,7 +202,7 @@ func handleGetDevices(conn net.Conn, req models.Request, manager *Manager) {
models.Respond(conn, req.ID, devices) models.Respond(conn, req.ID, devices)
} }
func handleGetPPDs(conn net.Conn, req models.Request, manager *Manager) { func handleGetPPDs(conn net.Conn, req Request, manager *Manager) {
ppds, err := manager.GetPPDs() ppds, err := manager.GetPPDs()
if err != nil { if err != nil {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, err.Error())
@@ -204,7 +211,7 @@ func handleGetPPDs(conn net.Conn, req models.Request, manager *Manager) {
models.Respond(conn, req.ID, ppds) models.Respond(conn, req.ID, ppds)
} }
func handleGetClasses(conn net.Conn, req models.Request, manager *Manager) { func handleGetClasses(conn net.Conn, req Request, manager *Manager) {
classes, err := manager.GetClasses() classes, err := manager.GetClasses()
if err != nil { if err != nil {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, err.Error())
@@ -213,41 +220,41 @@ func handleGetClasses(conn net.Conn, req models.Request, manager *Manager) {
models.Respond(conn, req.ID, classes) models.Respond(conn, req.ID, classes)
} }
func handleCreatePrinter(conn net.Conn, req models.Request, manager *Manager) { func handleCreatePrinter(conn net.Conn, req Request, manager *Manager) {
name, err := params.StringNonEmpty(req.Params, "name") name, ok := req.Params["name"].(string)
if err != nil { if !ok || name == "" {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, "missing or invalid 'name' parameter")
return return
} }
deviceURI, err := params.StringNonEmpty(req.Params, "deviceURI") deviceURI, ok := req.Params["deviceURI"].(string)
if err != nil { if !ok || deviceURI == "" {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, "missing or invalid 'deviceURI' parameter")
return return
} }
ppd, err := params.StringNonEmpty(req.Params, "ppd") ppd, ok := req.Params["ppd"].(string)
if err != nil { if !ok || ppd == "" {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, "missing or invalid 'ppd' parameter")
return return
} }
shared := params.BoolOpt(req.Params, "shared", false) shared, _ := req.Params["shared"].(bool)
errorPolicy := params.StringOpt(req.Params, "errorPolicy", "") errorPolicy, _ := req.Params["errorPolicy"].(string)
information := params.StringOpt(req.Params, "information", "") information, _ := req.Params["information"].(string)
location := params.StringOpt(req.Params, "location", "") location, _ := req.Params["location"].(string)
if err := manager.CreatePrinter(name, deviceURI, ppd, shared, errorPolicy, information, location); err != nil { if err := manager.CreatePrinter(name, deviceURI, ppd, shared, errorPolicy, information, location); err != nil {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, err.Error())
return return
} }
models.Respond(conn, req.ID, models.SuccessResult{Success: true, Message: "printer created"}) models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "printer created"})
} }
func handleDeletePrinter(conn net.Conn, req models.Request, manager *Manager) { func handleDeletePrinter(conn net.Conn, req Request, manager *Manager) {
printerName, err := params.StringNonEmpty(req.Params, "printerName") printerName, ok := req.Params["printerName"].(string)
if err != nil { if !ok || printerName == "" {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, "missing or invalid 'printerName' parameter")
return return
} }
@@ -255,13 +262,13 @@ func handleDeletePrinter(conn net.Conn, req models.Request, manager *Manager) {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, err.Error())
return return
} }
models.Respond(conn, req.ID, models.SuccessResult{Success: true, Message: "printer deleted"}) models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "printer deleted"})
} }
func handleAcceptJobs(conn net.Conn, req models.Request, manager *Manager) { func handleAcceptJobs(conn net.Conn, req Request, manager *Manager) {
printerName, err := params.StringNonEmpty(req.Params, "printerName") printerName, ok := req.Params["printerName"].(string)
if err != nil { if !ok || printerName == "" {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, "missing or invalid 'printerName' parameter")
return return
} }
@@ -269,13 +276,13 @@ func handleAcceptJobs(conn net.Conn, req models.Request, manager *Manager) {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, err.Error())
return return
} }
models.Respond(conn, req.ID, models.SuccessResult{Success: true, Message: "accepting jobs"}) models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "accepting jobs"})
} }
func handleRejectJobs(conn net.Conn, req models.Request, manager *Manager) { func handleRejectJobs(conn net.Conn, req Request, manager *Manager) {
printerName, err := params.StringNonEmpty(req.Params, "printerName") printerName, ok := req.Params["printerName"].(string)
if err != nil { if !ok || printerName == "" {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, "missing or invalid 'printerName' parameter")
return return
} }
@@ -283,19 +290,19 @@ func handleRejectJobs(conn net.Conn, req models.Request, manager *Manager) {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, err.Error())
return return
} }
models.Respond(conn, req.ID, models.SuccessResult{Success: true, Message: "rejecting jobs"}) models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "rejecting jobs"})
} }
func handleSetPrinterShared(conn net.Conn, req models.Request, manager *Manager) { func handleSetPrinterShared(conn net.Conn, req Request, manager *Manager) {
printerName, err := params.StringNonEmpty(req.Params, "printerName") printerName, ok := req.Params["printerName"].(string)
if err != nil { if !ok || printerName == "" {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, "missing or invalid 'printerName' parameter")
return return
} }
shared, err := params.Bool(req.Params, "shared") shared, ok := req.Params["shared"].(bool)
if err != nil { if !ok {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, "missing or invalid 'shared' parameter")
return return
} }
@@ -303,19 +310,19 @@ func handleSetPrinterShared(conn net.Conn, req models.Request, manager *Manager)
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, err.Error())
return return
} }
models.Respond(conn, req.ID, models.SuccessResult{Success: true, Message: "sharing updated"}) models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "sharing updated"})
} }
func handleSetPrinterLocation(conn net.Conn, req models.Request, manager *Manager) { func handleSetPrinterLocation(conn net.Conn, req Request, manager *Manager) {
printerName, err := params.StringNonEmpty(req.Params, "printerName") printerName, ok := req.Params["printerName"].(string)
if err != nil { if !ok || printerName == "" {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, "missing or invalid 'printerName' parameter")
return return
} }
location, err := params.String(req.Params, "location") location, ok := req.Params["location"].(string)
if err != nil { if !ok {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, "missing or invalid 'location' parameter")
return return
} }
@@ -323,19 +330,19 @@ func handleSetPrinterLocation(conn net.Conn, req models.Request, manager *Manage
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, err.Error())
return return
} }
models.Respond(conn, req.ID, models.SuccessResult{Success: true, Message: "location updated"}) models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "location updated"})
} }
func handleSetPrinterInfo(conn net.Conn, req models.Request, manager *Manager) { func handleSetPrinterInfo(conn net.Conn, req Request, manager *Manager) {
printerName, err := params.StringNonEmpty(req.Params, "printerName") printerName, ok := req.Params["printerName"].(string)
if err != nil { if !ok || printerName == "" {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, "missing or invalid 'printerName' parameter")
return return
} }
info, err := params.String(req.Params, "info") info, ok := req.Params["info"].(string)
if err != nil { if !ok {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, "missing or invalid 'info' parameter")
return return
} }
@@ -343,33 +350,39 @@ func handleSetPrinterInfo(conn net.Conn, req models.Request, manager *Manager) {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, err.Error())
return return
} }
models.Respond(conn, req.ID, models.SuccessResult{Success: true, Message: "info updated"}) models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "info updated"})
} }
func handleMoveJob(conn net.Conn, req models.Request, manager *Manager) { func handleMoveJob(conn net.Conn, req Request, manager *Manager) {
jobID, err := params.Int(req.Params, "jobID") jobIDFloat, ok := req.Params["jobID"].(float64)
if err != nil { if !ok {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, "missing or invalid 'jobID' parameter")
return return
} }
destPrinter, err := params.StringNonEmpty(req.Params, "destPrinter") destPrinter, ok := req.Params["destPrinter"].(string)
if err != nil { if !ok || destPrinter == "" {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, "missing or invalid 'destPrinter' parameter")
return return
} }
if err := manager.MoveJob(jobID, destPrinter); err != nil { if err := manager.MoveJob(int(jobIDFloat), destPrinter); err != nil {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, err.Error())
return return
} }
models.Respond(conn, req.ID, models.SuccessResult{Success: true, Message: "job moved"}) models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "job moved"})
} }
func handlePrintTestPage(conn net.Conn, req models.Request, manager *Manager) { type TestPageResult struct {
printerName, err := params.StringNonEmpty(req.Params, "printerName") Success bool `json:"success"`
if err != nil { JobID int `json:"jobId"`
models.RespondError(conn, req.ID, err.Error()) Message string `json:"message"`
}
func handlePrintTestPage(conn net.Conn, req Request, manager *Manager) {
printerName, ok := req.Params["printerName"].(string)
if !ok || printerName == "" {
models.RespondError(conn, req.ID, "missing or invalid 'printerName' parameter")
return return
} }
@@ -381,16 +394,16 @@ func handlePrintTestPage(conn net.Conn, req models.Request, manager *Manager) {
models.Respond(conn, req.ID, TestPageResult{Success: true, JobID: jobID, Message: "test page queued"}) models.Respond(conn, req.ID, TestPageResult{Success: true, JobID: jobID, Message: "test page queued"})
} }
func handleAddPrinterToClass(conn net.Conn, req models.Request, manager *Manager) { func handleAddPrinterToClass(conn net.Conn, req Request, manager *Manager) {
className, err := params.StringNonEmpty(req.Params, "className") className, ok := req.Params["className"].(string)
if err != nil { if !ok || className == "" {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, "missing or invalid 'className' parameter")
return return
} }
printerName, err := params.StringNonEmpty(req.Params, "printerName") printerName, ok := req.Params["printerName"].(string)
if err != nil { if !ok || printerName == "" {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, "missing or invalid 'printerName' parameter")
return return
} }
@@ -398,19 +411,19 @@ func handleAddPrinterToClass(conn net.Conn, req models.Request, manager *Manager
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, err.Error())
return return
} }
models.Respond(conn, req.ID, models.SuccessResult{Success: true, Message: "printer added to class"}) models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "printer added to class"})
} }
func handleRemovePrinterFromClass(conn net.Conn, req models.Request, manager *Manager) { func handleRemovePrinterFromClass(conn net.Conn, req Request, manager *Manager) {
className, err := params.StringNonEmpty(req.Params, "className") className, ok := req.Params["className"].(string)
if err != nil { if !ok || className == "" {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, "missing or invalid 'className' parameter")
return return
} }
printerName, err := params.StringNonEmpty(req.Params, "printerName") printerName, ok := req.Params["printerName"].(string)
if err != nil { if !ok || printerName == "" {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, "missing or invalid 'printerName' parameter")
return return
} }
@@ -418,13 +431,13 @@ func handleRemovePrinterFromClass(conn net.Conn, req models.Request, manager *Ma
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, err.Error())
return return
} }
models.Respond(conn, req.ID, models.SuccessResult{Success: true, Message: "printer removed from class"}) models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "printer removed from class"})
} }
func handleDeleteClass(conn net.Conn, req models.Request, manager *Manager) { func handleDeleteClass(conn net.Conn, req Request, manager *Manager) {
className, err := params.StringNonEmpty(req.Params, "className") className, ok := req.Params["className"].(string)
if err != nil { if !ok || className == "" {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, "missing or invalid 'className' parameter")
return return
} }
@@ -432,35 +445,38 @@ func handleDeleteClass(conn net.Conn, req models.Request, manager *Manager) {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, err.Error())
return return
} }
models.Respond(conn, req.ID, models.SuccessResult{Success: true, Message: "class deleted"}) models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "class deleted"})
} }
func handleRestartJob(conn net.Conn, req models.Request, manager *Manager) { func handleRestartJob(conn net.Conn, req Request, manager *Manager) {
jobID, err := params.Int(req.Params, "jobID") jobIDFloat, ok := req.Params["jobID"].(float64)
if err != nil { if !ok {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, "missing or invalid 'jobID' parameter")
return return
} }
if err := manager.RestartJob(jobID); err != nil { if err := manager.RestartJob(int(jobIDFloat)); err != nil {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, err.Error())
return return
} }
models.Respond(conn, req.ID, models.SuccessResult{Success: true, Message: "job restarted"}) models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "job restarted"})
} }
func handleHoldJob(conn net.Conn, req models.Request, manager *Manager) { func handleHoldJob(conn net.Conn, req Request, manager *Manager) {
jobID, err := params.Int(req.Params, "jobID") jobIDFloat, ok := req.Params["jobID"].(float64)
if err != nil { if !ok {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, "missing or invalid 'jobID' parameter")
return return
} }
holdUntil := params.StringOpt(req.Params, "holdUntil", "indefinite") holdUntil, _ := req.Params["holdUntil"].(string)
if holdUntil == "" {
holdUntil = "indefinite"
}
if err := manager.HoldJob(jobID, holdUntil); err != nil { if err := manager.HoldJob(int(jobIDFloat), holdUntil); err != nil {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, err.Error())
return return
} }
models.Respond(conn, req.ID, models.SuccessResult{Success: true, Message: "job held"}) models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "job held"})
} }

View File

@@ -43,7 +43,7 @@ func TestHandleGetPrinters(t *testing.T) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
conn := &mockConn{Buffer: buf} conn := &mockConn{Buffer: buf}
req := models.Request{ req := Request{
ID: 1, ID: 1,
Method: "cups.getPrinters", Method: "cups.getPrinters",
} }
@@ -68,7 +68,7 @@ func TestHandleGetPrinters_Error(t *testing.T) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
conn := &mockConn{Buffer: buf} conn := &mockConn{Buffer: buf}
req := models.Request{ req := Request{
ID: 1, ID: 1,
Method: "cups.getPrinters", Method: "cups.getPrinters",
} }
@@ -100,7 +100,7 @@ func TestHandleGetJobs(t *testing.T) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
conn := &mockConn{Buffer: buf} conn := &mockConn{Buffer: buf}
req := models.Request{ req := Request{
ID: 1, ID: 1,
Method: "cups.getJobs", Method: "cups.getJobs",
Params: map[string]any{ Params: map[string]any{
@@ -127,7 +127,7 @@ func TestHandleGetJobs_MissingParam(t *testing.T) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
conn := &mockConn{Buffer: buf} conn := &mockConn{Buffer: buf}
req := models.Request{ req := Request{
ID: 1, ID: 1,
Method: "cups.getJobs", Method: "cups.getJobs",
Params: map[string]any{}, Params: map[string]any{},
@@ -152,7 +152,7 @@ func TestHandlePausePrinter(t *testing.T) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
conn := &mockConn{Buffer: buf} conn := &mockConn{Buffer: buf}
req := models.Request{ req := Request{
ID: 1, ID: 1,
Method: "cups.pausePrinter", Method: "cups.pausePrinter",
Params: map[string]any{ Params: map[string]any{
@@ -162,7 +162,7 @@ func TestHandlePausePrinter(t *testing.T) {
handlePausePrinter(conn, req, m) handlePausePrinter(conn, req, m)
var resp models.Response[models.SuccessResult] var resp models.Response[SuccessResult]
err := json.NewDecoder(buf).Decode(&resp) err := json.NewDecoder(buf).Decode(&resp)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, resp.Result) assert.NotNil(t, resp.Result)
@@ -179,7 +179,7 @@ func TestHandleResumePrinter(t *testing.T) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
conn := &mockConn{Buffer: buf} conn := &mockConn{Buffer: buf}
req := models.Request{ req := Request{
ID: 1, ID: 1,
Method: "cups.resumePrinter", Method: "cups.resumePrinter",
Params: map[string]any{ Params: map[string]any{
@@ -189,7 +189,7 @@ func TestHandleResumePrinter(t *testing.T) {
handleResumePrinter(conn, req, m) handleResumePrinter(conn, req, m)
var resp models.Response[models.SuccessResult] var resp models.Response[SuccessResult]
err := json.NewDecoder(buf).Decode(&resp) err := json.NewDecoder(buf).Decode(&resp)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, resp.Result) assert.NotNil(t, resp.Result)
@@ -206,7 +206,7 @@ func TestHandleCancelJob(t *testing.T) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
conn := &mockConn{Buffer: buf} conn := &mockConn{Buffer: buf}
req := models.Request{ req := Request{
ID: 1, ID: 1,
Method: "cups.cancelJob", Method: "cups.cancelJob",
Params: map[string]any{ Params: map[string]any{
@@ -216,7 +216,7 @@ func TestHandleCancelJob(t *testing.T) {
handleCancelJob(conn, req, m) handleCancelJob(conn, req, m)
var resp models.Response[models.SuccessResult] var resp models.Response[SuccessResult]
err := json.NewDecoder(buf).Decode(&resp) err := json.NewDecoder(buf).Decode(&resp)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, resp.Result) assert.NotNil(t, resp.Result)
@@ -233,7 +233,7 @@ func TestHandlePurgeJobs(t *testing.T) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
conn := &mockConn{Buffer: buf} conn := &mockConn{Buffer: buf}
req := models.Request{ req := Request{
ID: 1, ID: 1,
Method: "cups.purgeJobs", Method: "cups.purgeJobs",
Params: map[string]any{ Params: map[string]any{
@@ -243,7 +243,7 @@ func TestHandlePurgeJobs(t *testing.T) {
handlePurgeJobs(conn, req, m) handlePurgeJobs(conn, req, m)
var resp models.Response[models.SuccessResult] var resp models.Response[SuccessResult]
err := json.NewDecoder(buf).Decode(&resp) err := json.NewDecoder(buf).Decode(&resp)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, resp.Result) assert.NotNil(t, resp.Result)
@@ -260,7 +260,7 @@ func TestHandleRequest_UnknownMethod(t *testing.T) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
conn := &mockConn{Buffer: buf} conn := &mockConn{Buffer: buf}
req := models.Request{ req := Request{
ID: 1, ID: 1,
Method: "cups.unknownMethod", Method: "cups.unknownMethod",
} }
@@ -287,7 +287,7 @@ func TestHandleGetDevices(t *testing.T) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
conn := &mockConn{Buffer: buf} conn := &mockConn{Buffer: buf}
req := models.Request{ID: 1, Method: "cups.getDevices"} req := Request{ID: 1, Method: "cups.getDevices"}
handleGetDevices(conn, req, m) handleGetDevices(conn, req, m)
var resp models.Response[[]Device] var resp models.Response[[]Device]
@@ -309,7 +309,7 @@ func TestHandleGetPPDs(t *testing.T) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
conn := &mockConn{Buffer: buf} conn := &mockConn{Buffer: buf}
req := models.Request{ID: 1, Method: "cups.getPPDs"} req := Request{ID: 1, Method: "cups.getPPDs"}
handleGetPPDs(conn, req, m) handleGetPPDs(conn, req, m)
var resp models.Response[[]PPD] var resp models.Response[[]PPD]
@@ -332,7 +332,7 @@ func TestHandleGetClasses(t *testing.T) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
conn := &mockConn{Buffer: buf} conn := &mockConn{Buffer: buf}
req := models.Request{ID: 1, Method: "cups.getClasses"} req := Request{ID: 1, Method: "cups.getClasses"}
handleGetClasses(conn, req, m) handleGetClasses(conn, req, m)
var resp models.Response[[]PrinterClass] var resp models.Response[[]PrinterClass]
@@ -353,7 +353,7 @@ func TestHandleCreatePrinter(t *testing.T) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
conn := &mockConn{Buffer: buf} conn := &mockConn{Buffer: buf}
req := models.Request{ req := Request{
ID: 1, ID: 1,
Method: "cups.createPrinter", Method: "cups.createPrinter",
Params: map[string]any{ Params: map[string]any{
@@ -364,7 +364,7 @@ func TestHandleCreatePrinter(t *testing.T) {
} }
handleCreatePrinter(conn, req, m) handleCreatePrinter(conn, req, m)
var resp models.Response[models.SuccessResult] var resp models.Response[SuccessResult]
err := json.NewDecoder(buf).Decode(&resp) err := json.NewDecoder(buf).Decode(&resp)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, resp.Result) assert.NotNil(t, resp.Result)
@@ -377,7 +377,7 @@ func TestHandleCreatePrinter_MissingParams(t *testing.T) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
conn := &mockConn{Buffer: buf} conn := &mockConn{Buffer: buf}
req := models.Request{ID: 1, Method: "cups.createPrinter", Params: map[string]any{}} req := Request{ID: 1, Method: "cups.createPrinter", Params: map[string]any{}}
handleCreatePrinter(conn, req, m) handleCreatePrinter(conn, req, m)
var resp models.Response[any] var resp models.Response[any]
@@ -396,14 +396,14 @@ func TestHandleDeletePrinter(t *testing.T) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
conn := &mockConn{Buffer: buf} conn := &mockConn{Buffer: buf}
req := models.Request{ req := Request{
ID: 1, ID: 1,
Method: "cups.deletePrinter", Method: "cups.deletePrinter",
Params: map[string]any{"printerName": "printer1"}, Params: map[string]any{"printerName": "printer1"},
} }
handleDeletePrinter(conn, req, m) handleDeletePrinter(conn, req, m)
var resp models.Response[models.SuccessResult] var resp models.Response[SuccessResult]
err := json.NewDecoder(buf).Decode(&resp) err := json.NewDecoder(buf).Decode(&resp)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, resp.Result) assert.NotNil(t, resp.Result)
@@ -419,14 +419,14 @@ func TestHandleAcceptJobs(t *testing.T) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
conn := &mockConn{Buffer: buf} conn := &mockConn{Buffer: buf}
req := models.Request{ req := Request{
ID: 1, ID: 1,
Method: "cups.acceptJobs", Method: "cups.acceptJobs",
Params: map[string]any{"printerName": "printer1"}, Params: map[string]any{"printerName": "printer1"},
} }
handleAcceptJobs(conn, req, m) handleAcceptJobs(conn, req, m)
var resp models.Response[models.SuccessResult] var resp models.Response[SuccessResult]
err := json.NewDecoder(buf).Decode(&resp) err := json.NewDecoder(buf).Decode(&resp)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, resp.Result) assert.NotNil(t, resp.Result)
@@ -442,14 +442,14 @@ func TestHandleRejectJobs(t *testing.T) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
conn := &mockConn{Buffer: buf} conn := &mockConn{Buffer: buf}
req := models.Request{ req := Request{
ID: 1, ID: 1,
Method: "cups.rejectJobs", Method: "cups.rejectJobs",
Params: map[string]any{"printerName": "printer1"}, Params: map[string]any{"printerName": "printer1"},
} }
handleRejectJobs(conn, req, m) handleRejectJobs(conn, req, m)
var resp models.Response[models.SuccessResult] var resp models.Response[SuccessResult]
err := json.NewDecoder(buf).Decode(&resp) err := json.NewDecoder(buf).Decode(&resp)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, resp.Result) assert.NotNil(t, resp.Result)
@@ -465,14 +465,14 @@ func TestHandleSetPrinterShared(t *testing.T) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
conn := &mockConn{Buffer: buf} conn := &mockConn{Buffer: buf}
req := models.Request{ req := Request{
ID: 1, ID: 1,
Method: "cups.setPrinterShared", Method: "cups.setPrinterShared",
Params: map[string]any{"printerName": "printer1", "shared": true}, Params: map[string]any{"printerName": "printer1", "shared": true},
} }
handleSetPrinterShared(conn, req, m) handleSetPrinterShared(conn, req, m)
var resp models.Response[models.SuccessResult] var resp models.Response[SuccessResult]
err := json.NewDecoder(buf).Decode(&resp) err := json.NewDecoder(buf).Decode(&resp)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, resp.Result) assert.NotNil(t, resp.Result)
@@ -488,14 +488,14 @@ func TestHandleSetPrinterLocation(t *testing.T) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
conn := &mockConn{Buffer: buf} conn := &mockConn{Buffer: buf}
req := models.Request{ req := Request{
ID: 1, ID: 1,
Method: "cups.setPrinterLocation", Method: "cups.setPrinterLocation",
Params: map[string]any{"printerName": "printer1", "location": "Office"}, Params: map[string]any{"printerName": "printer1", "location": "Office"},
} }
handleSetPrinterLocation(conn, req, m) handleSetPrinterLocation(conn, req, m)
var resp models.Response[models.SuccessResult] var resp models.Response[SuccessResult]
err := json.NewDecoder(buf).Decode(&resp) err := json.NewDecoder(buf).Decode(&resp)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, resp.Result) assert.NotNil(t, resp.Result)
@@ -511,14 +511,14 @@ func TestHandleSetPrinterInfo(t *testing.T) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
conn := &mockConn{Buffer: buf} conn := &mockConn{Buffer: buf}
req := models.Request{ req := Request{
ID: 1, ID: 1,
Method: "cups.setPrinterInfo", Method: "cups.setPrinterInfo",
Params: map[string]any{"printerName": "printer1", "info": "Main Printer"}, Params: map[string]any{"printerName": "printer1", "info": "Main Printer"},
} }
handleSetPrinterInfo(conn, req, m) handleSetPrinterInfo(conn, req, m)
var resp models.Response[models.SuccessResult] var resp models.Response[SuccessResult]
err := json.NewDecoder(buf).Decode(&resp) err := json.NewDecoder(buf).Decode(&resp)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, resp.Result) assert.NotNil(t, resp.Result)
@@ -534,14 +534,14 @@ func TestHandleMoveJob(t *testing.T) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
conn := &mockConn{Buffer: buf} conn := &mockConn{Buffer: buf}
req := models.Request{ req := Request{
ID: 1, ID: 1,
Method: "cups.moveJob", Method: "cups.moveJob",
Params: map[string]any{"jobID": float64(1), "destPrinter": "printer2"}, Params: map[string]any{"jobID": float64(1), "destPrinter": "printer2"},
} }
handleMoveJob(conn, req, m) handleMoveJob(conn, req, m)
var resp models.Response[models.SuccessResult] var resp models.Response[SuccessResult]
err := json.NewDecoder(buf).Decode(&resp) err := json.NewDecoder(buf).Decode(&resp)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, resp.Result) assert.NotNil(t, resp.Result)
@@ -557,7 +557,7 @@ func TestHandlePrintTestPage(t *testing.T) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
conn := &mockConn{Buffer: buf} conn := &mockConn{Buffer: buf}
req := models.Request{ req := Request{
ID: 1, ID: 1,
Method: "cups.printTestPage", Method: "cups.printTestPage",
Params: map[string]any{"printerName": "printer1"}, Params: map[string]any{"printerName": "printer1"},
@@ -581,14 +581,14 @@ func TestHandleAddPrinterToClass(t *testing.T) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
conn := &mockConn{Buffer: buf} conn := &mockConn{Buffer: buf}
req := models.Request{ req := Request{
ID: 1, ID: 1,
Method: "cups.addPrinterToClass", Method: "cups.addPrinterToClass",
Params: map[string]any{"className": "office", "printerName": "printer1"}, Params: map[string]any{"className": "office", "printerName": "printer1"},
} }
handleAddPrinterToClass(conn, req, m) handleAddPrinterToClass(conn, req, m)
var resp models.Response[models.SuccessResult] var resp models.Response[SuccessResult]
err := json.NewDecoder(buf).Decode(&resp) err := json.NewDecoder(buf).Decode(&resp)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, resp.Result) assert.NotNil(t, resp.Result)
@@ -604,14 +604,14 @@ func TestHandleRemovePrinterFromClass(t *testing.T) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
conn := &mockConn{Buffer: buf} conn := &mockConn{Buffer: buf}
req := models.Request{ req := Request{
ID: 1, ID: 1,
Method: "cups.removePrinterFromClass", Method: "cups.removePrinterFromClass",
Params: map[string]any{"className": "office", "printerName": "printer1"}, Params: map[string]any{"className": "office", "printerName": "printer1"},
} }
handleRemovePrinterFromClass(conn, req, m) handleRemovePrinterFromClass(conn, req, m)
var resp models.Response[models.SuccessResult] var resp models.Response[SuccessResult]
err := json.NewDecoder(buf).Decode(&resp) err := json.NewDecoder(buf).Decode(&resp)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, resp.Result) assert.NotNil(t, resp.Result)
@@ -627,14 +627,14 @@ func TestHandleDeleteClass(t *testing.T) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
conn := &mockConn{Buffer: buf} conn := &mockConn{Buffer: buf}
req := models.Request{ req := Request{
ID: 1, ID: 1,
Method: "cups.deleteClass", Method: "cups.deleteClass",
Params: map[string]any{"className": "office"}, Params: map[string]any{"className": "office"},
} }
handleDeleteClass(conn, req, m) handleDeleteClass(conn, req, m)
var resp models.Response[models.SuccessResult] var resp models.Response[SuccessResult]
err := json.NewDecoder(buf).Decode(&resp) err := json.NewDecoder(buf).Decode(&resp)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, resp.Result) assert.NotNil(t, resp.Result)
@@ -650,14 +650,14 @@ func TestHandleRestartJob(t *testing.T) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
conn := &mockConn{Buffer: buf} conn := &mockConn{Buffer: buf}
req := models.Request{ req := Request{
ID: 1, ID: 1,
Method: "cups.restartJob", Method: "cups.restartJob",
Params: map[string]any{"jobID": float64(1)}, Params: map[string]any{"jobID": float64(1)},
} }
handleRestartJob(conn, req, m) handleRestartJob(conn, req, m)
var resp models.Response[models.SuccessResult] var resp models.Response[SuccessResult]
err := json.NewDecoder(buf).Decode(&resp) err := json.NewDecoder(buf).Decode(&resp)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, resp.Result) assert.NotNil(t, resp.Result)
@@ -673,14 +673,14 @@ func TestHandleHoldJob(t *testing.T) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
conn := &mockConn{Buffer: buf} conn := &mockConn{Buffer: buf}
req := models.Request{ req := Request{
ID: 1, ID: 1,
Method: "cups.holdJob", Method: "cups.holdJob",
Params: map[string]any{"jobID": float64(1)}, Params: map[string]any{"jobID": float64(1)},
} }
handleHoldJob(conn, req, m) handleHoldJob(conn, req, m)
var resp models.Response[models.SuccessResult] var resp models.Response[SuccessResult]
err := json.NewDecoder(buf).Decode(&resp) err := json.NewDecoder(buf).Decode(&resp)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, resp.Result) assert.NotNil(t, resp.Result)
@@ -696,14 +696,14 @@ func TestHandleHoldJob_WithHoldUntil(t *testing.T) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
conn := &mockConn{Buffer: buf} conn := &mockConn{Buffer: buf}
req := models.Request{ req := Request{
ID: 1, ID: 1,
Method: "cups.holdJob", Method: "cups.holdJob",
Params: map[string]any{"jobID": float64(1), "holdUntil": "no-hold"}, Params: map[string]any{"jobID": float64(1), "holdUntil": "no-hold"},
} }
handleHoldJob(conn, req, m) handleHoldJob(conn, req, m)
var resp models.Response[models.SuccessResult] var resp models.Response[SuccessResult]
err := json.NewDecoder(buf).Decode(&resp) err := json.NewDecoder(buf).Decode(&resp)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, resp.Result) assert.NotNil(t, resp.Result)

View File

@@ -8,12 +8,18 @@ import (
"github.com/AvengeMedia/DankMaterialShell/core/internal/server/models" "github.com/AvengeMedia/DankMaterialShell/core/internal/server/models"
) )
type Request struct {
ID int `json:"id,omitempty"`
Method string `json:"method"`
Params map[string]any `json:"params,omitempty"`
}
type SuccessResult struct { type SuccessResult struct {
Success bool `json:"success"` Success bool `json:"success"`
Message string `json:"message"` Message string `json:"message"`
} }
func HandleRequest(conn net.Conn, req models.Request, manager *Manager) { func HandleRequest(conn net.Conn, req Request, manager *Manager) {
if manager == nil { if manager == nil {
models.RespondError(conn, req.ID, "dwl manager not initialized") models.RespondError(conn, req.ID, "dwl manager not initialized")
return return
@@ -35,12 +41,12 @@ func HandleRequest(conn net.Conn, req models.Request, manager *Manager) {
} }
} }
func handleGetState(conn net.Conn, req models.Request, manager *Manager) { func handleGetState(conn net.Conn, req Request, manager *Manager) {
state := manager.GetState() state := manager.GetState()
models.Respond(conn, req.ID, state) models.Respond(conn, req.ID, state)
} }
func handleSetTags(conn net.Conn, req models.Request, manager *Manager) { func handleSetTags(conn net.Conn, req Request, manager *Manager) {
output, ok := req.Params["output"].(string) output, ok := req.Params["output"].(string)
if !ok { if !ok {
models.RespondError(conn, req.ID, "missing or invalid 'output' parameter") models.RespondError(conn, req.ID, "missing or invalid 'output' parameter")
@@ -67,7 +73,7 @@ func handleSetTags(conn net.Conn, req models.Request, manager *Manager) {
models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "tags set"}) models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "tags set"})
} }
func handleSetClientTags(conn net.Conn, req models.Request, manager *Manager) { func handleSetClientTags(conn net.Conn, req Request, manager *Manager) {
output, ok := req.Params["output"].(string) output, ok := req.Params["output"].(string)
if !ok { if !ok {
models.RespondError(conn, req.ID, "missing or invalid 'output' parameter") models.RespondError(conn, req.ID, "missing or invalid 'output' parameter")
@@ -94,7 +100,7 @@ func handleSetClientTags(conn net.Conn, req models.Request, manager *Manager) {
models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "client tags set"}) models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "client tags set"})
} }
func handleSetLayout(conn net.Conn, req models.Request, manager *Manager) { func handleSetLayout(conn net.Conn, req Request, manager *Manager) {
output, ok := req.Params["output"].(string) output, ok := req.Params["output"].(string)
if !ok { if !ok {
models.RespondError(conn, req.ID, "missing or invalid 'output' parameter") models.RespondError(conn, req.ID, "missing or invalid 'output' parameter")
@@ -115,7 +121,7 @@ func handleSetLayout(conn net.Conn, req models.Request, manager *Manager) {
models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "layout set"}) models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "layout set"})
} }
func handleSubscribe(conn net.Conn, req models.Request, manager *Manager) { func handleSubscribe(conn net.Conn, req Request, manager *Manager) {
clientID := fmt.Sprintf("client-%p", conn) clientID := fmt.Sprintf("client-%p", conn)
stateChan := manager.Subscribe(clientID) stateChan := manager.Subscribe(clientID)
defer manager.Unsubscribe(clientID) defer manager.Unsubscribe(clientID)

View File

@@ -6,15 +6,22 @@ import (
"github.com/AvengeMedia/DankMaterialShell/core/internal/server/models" "github.com/AvengeMedia/DankMaterialShell/core/internal/server/models"
) )
func HandleRequest(conn net.Conn, req models.Request, m *Manager) { type Request struct {
ID any `json:"id"`
Method string `json:"method"`
Params map[string]any `json:"params"`
}
func HandleRequest(conn net.Conn, req Request, m *Manager) {
switch req.Method { switch req.Method {
case "evdev.getState": case "evdev.getState":
handleGetState(conn, req, m) handleGetState(conn, req, m)
default: default:
models.RespondError(conn, req.ID, "unknown method: "+req.Method) models.RespondError(conn, req.ID.(int), "unknown method: "+req.Method)
} }
} }
func handleGetState(conn net.Conn, req models.Request, m *Manager) { func handleGetState(conn net.Conn, req Request, m *Manager) {
models.Respond(conn, req.ID, m.GetState()) state := m.GetState()
models.Respond(conn, req.ID.(int), state)
} }

View File

@@ -53,7 +53,7 @@ func TestHandleRequest(t *testing.T) {
} }
conn := newMockNetConn() conn := newMockNetConn()
req := models.Request{ req := Request{
ID: 123, ID: 123,
Method: "evdev.getState", Method: "evdev.getState",
Params: map[string]any{}, Params: map[string]any{},
@@ -82,7 +82,7 @@ func TestHandleRequest(t *testing.T) {
} }
conn := newMockNetConn() conn := newMockNetConn()
req := models.Request{ req := Request{
ID: 456, ID: 456,
Method: "evdev.unknownMethod", Method: "evdev.unknownMethod",
Params: map[string]any{}, Params: map[string]any{},
@@ -111,7 +111,7 @@ func TestHandleGetState(t *testing.T) {
} }
conn := newMockNetConn() conn := newMockNetConn()
req := models.Request{ req := Request{
ID: 789, ID: 789,
Method: "evdev.getState", Method: "evdev.getState",
Params: map[string]any{}, Params: map[string]any{},

View File

@@ -306,15 +306,6 @@ func (m *Manager) readAndUpdateCapsLockState(deviceIndex int) {
return return
} }
if len(ledStates) == 0 {
log.Debug("No LED state available (empty map)")
// This means the device either:
// - doesn't support LED reporting at all, or
// - the kernel returned an empty state
return
}
capsLockState := ledStates[ledCapslockKey] capsLockState := ledStates[ledCapslockKey]
m.updateCapsLockStateDirect(capsLockState) m.updateCapsLockStateDirect(capsLockState)
} }

View File

@@ -8,12 +8,18 @@ import (
"github.com/AvengeMedia/DankMaterialShell/core/internal/server/models" "github.com/AvengeMedia/DankMaterialShell/core/internal/server/models"
) )
type Request struct {
ID int `json:"id,omitempty"`
Method string `json:"method"`
Params map[string]any `json:"params,omitempty"`
}
type SuccessResult struct { type SuccessResult struct {
Success bool `json:"success"` Success bool `json:"success"`
Message string `json:"message"` Message string `json:"message"`
} }
func HandleRequest(conn net.Conn, req models.Request, manager *Manager) { func HandleRequest(conn net.Conn, req Request, manager *Manager) {
if manager == nil { if manager == nil {
models.RespondError(conn, req.ID, "extworkspace manager not initialized") models.RespondError(conn, req.ID, "extworkspace manager not initialized")
return return
@@ -37,12 +43,12 @@ func HandleRequest(conn net.Conn, req models.Request, manager *Manager) {
} }
} }
func handleGetState(conn net.Conn, req models.Request, manager *Manager) { func handleGetState(conn net.Conn, req Request, manager *Manager) {
state := manager.GetState() state := manager.GetState()
models.Respond(conn, req.ID, state) models.Respond(conn, req.ID, state)
} }
func handleActivateWorkspace(conn net.Conn, req models.Request, manager *Manager) { func handleActivateWorkspace(conn net.Conn, req Request, manager *Manager) {
groupID, ok := req.Params["groupID"].(string) groupID, ok := req.Params["groupID"].(string)
if !ok { if !ok {
groupID = "" groupID = ""
@@ -62,7 +68,7 @@ func handleActivateWorkspace(conn net.Conn, req models.Request, manager *Manager
models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "workspace activated"}) models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "workspace activated"})
} }
func handleDeactivateWorkspace(conn net.Conn, req models.Request, manager *Manager) { func handleDeactivateWorkspace(conn net.Conn, req Request, manager *Manager) {
groupID, ok := req.Params["groupID"].(string) groupID, ok := req.Params["groupID"].(string)
if !ok { if !ok {
groupID = "" groupID = ""
@@ -82,7 +88,7 @@ func handleDeactivateWorkspace(conn net.Conn, req models.Request, manager *Manag
models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "workspace deactivated"}) models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "workspace deactivated"})
} }
func handleRemoveWorkspace(conn net.Conn, req models.Request, manager *Manager) { func handleRemoveWorkspace(conn net.Conn, req Request, manager *Manager) {
groupID, ok := req.Params["groupID"].(string) groupID, ok := req.Params["groupID"].(string)
if !ok { if !ok {
groupID = "" groupID = ""
@@ -102,7 +108,7 @@ func handleRemoveWorkspace(conn net.Conn, req models.Request, manager *Manager)
models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "workspace removed"}) models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "workspace removed"})
} }
func handleCreateWorkspace(conn net.Conn, req models.Request, manager *Manager) { func handleCreateWorkspace(conn net.Conn, req Request, manager *Manager) {
groupID, ok := req.Params["groupID"].(string) groupID, ok := req.Params["groupID"].(string)
if !ok { if !ok {
models.RespondError(conn, req.ID, "missing or invalid 'groupID' parameter") models.RespondError(conn, req.ID, "missing or invalid 'groupID' parameter")
@@ -123,7 +129,7 @@ func handleCreateWorkspace(conn net.Conn, req models.Request, manager *Manager)
models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "workspace create requested"}) models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "workspace create requested"})
} }
func handleSubscribe(conn net.Conn, req models.Request, manager *Manager) { func handleSubscribe(conn net.Conn, req Request, manager *Manager) {
clientID := fmt.Sprintf("client-%p", conn) clientID := fmt.Sprintf("client-%p", conn)
stateChan := manager.Subscribe(clientID) stateChan := manager.Subscribe(clientID)
defer manager.Unsubscribe(clientID) defer manager.Unsubscribe(clientID)

View File

@@ -5,10 +5,21 @@ import (
"net" "net"
"github.com/AvengeMedia/DankMaterialShell/core/internal/server/models" "github.com/AvengeMedia/DankMaterialShell/core/internal/server/models"
"github.com/AvengeMedia/DankMaterialShell/core/internal/server/params"
) )
func HandleRequest(conn net.Conn, req models.Request, manager *Manager) { type Request struct {
ID int `json:"id,omitempty"`
Method string `json:"method"`
Params map[string]any `json:"params,omitempty"`
}
type SuccessResult struct {
Success bool `json:"success"`
Message string `json:"message"`
Value string `json:"value,omitempty"`
}
func HandleRequest(conn net.Conn, req Request, manager *Manager) {
switch req.Method { switch req.Method {
case "freedesktop.getState": case "freedesktop.getState":
handleGetState(conn, req, manager) handleGetState(conn, req, manager)
@@ -33,14 +44,15 @@ func HandleRequest(conn net.Conn, req models.Request, manager *Manager) {
} }
} }
func handleGetState(conn net.Conn, req models.Request, manager *Manager) { func handleGetState(conn net.Conn, req Request, manager *Manager) {
models.Respond(conn, req.ID, manager.GetState()) state := manager.GetState()
models.Respond(conn, req.ID, state)
} }
func handleSetIconFile(conn net.Conn, req models.Request, manager *Manager) { func handleSetIconFile(conn net.Conn, req Request, manager *Manager) {
iconPath, err := params.String(req.Params, "path") iconPath, ok := req.Params["path"].(string)
if err != nil { if !ok {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, "missing or invalid 'path' parameter")
return return
} }
@@ -49,13 +61,13 @@ func handleSetIconFile(conn net.Conn, req models.Request, manager *Manager) {
return return
} }
models.Respond(conn, req.ID, models.SuccessResult{Success: true, Message: "icon file set"}) models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "icon file set"})
} }
func handleSetRealName(conn net.Conn, req models.Request, manager *Manager) { func handleSetRealName(conn net.Conn, req Request, manager *Manager) {
name, err := params.String(req.Params, "name") name, ok := req.Params["name"].(string)
if err != nil { if !ok {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, "missing or invalid 'name' parameter")
return return
} }
@@ -64,13 +76,13 @@ func handleSetRealName(conn net.Conn, req models.Request, manager *Manager) {
return return
} }
models.Respond(conn, req.ID, models.SuccessResult{Success: true, Message: "real name set"}) models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "real name set"})
} }
func handleSetEmail(conn net.Conn, req models.Request, manager *Manager) { func handleSetEmail(conn net.Conn, req Request, manager *Manager) {
email, err := params.String(req.Params, "email") email, ok := req.Params["email"].(string)
if err != nil { if !ok {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, "missing or invalid 'email' parameter")
return return
} }
@@ -79,13 +91,13 @@ func handleSetEmail(conn net.Conn, req models.Request, manager *Manager) {
return return
} }
models.Respond(conn, req.ID, models.SuccessResult{Success: true, Message: "email set"}) models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "email set"})
} }
func handleSetLanguage(conn net.Conn, req models.Request, manager *Manager) { func handleSetLanguage(conn net.Conn, req Request, manager *Manager) {
language, err := params.String(req.Params, "language") language, ok := req.Params["language"].(string)
if err != nil { if !ok {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, "missing or invalid 'language' parameter")
return return
} }
@@ -94,13 +106,13 @@ func handleSetLanguage(conn net.Conn, req models.Request, manager *Manager) {
return return
} }
models.Respond(conn, req.ID, models.SuccessResult{Success: true, Message: "language set"}) models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "language set"})
} }
func handleSetLocation(conn net.Conn, req models.Request, manager *Manager) { func handleSetLocation(conn net.Conn, req Request, manager *Manager) {
location, err := params.String(req.Params, "location") location, ok := req.Params["location"].(string)
if err != nil { if !ok {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, "missing or invalid 'location' parameter")
return return
} }
@@ -109,13 +121,13 @@ func handleSetLocation(conn net.Conn, req models.Request, manager *Manager) {
return return
} }
models.Respond(conn, req.ID, models.SuccessResult{Success: true, Message: "location set"}) models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "location set"})
} }
func handleGetUserIconFile(conn net.Conn, req models.Request, manager *Manager) { func handleGetUserIconFile(conn net.Conn, req Request, manager *Manager) {
username, err := params.String(req.Params, "username") username, ok := req.Params["username"].(string)
if err != nil { if !ok {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, "missing or invalid 'username' parameter")
return return
} }
@@ -125,10 +137,10 @@ func handleGetUserIconFile(conn net.Conn, req models.Request, manager *Manager)
return return
} }
models.Respond(conn, req.ID, models.SuccessResult{Success: true, Value: iconFile}) models.Respond(conn, req.ID, SuccessResult{Success: true, Value: iconFile})
} }
func handleGetColorScheme(conn net.Conn, req models.Request, manager *Manager) { func handleGetColorScheme(conn net.Conn, req Request, manager *Manager) {
if err := manager.updateSettingsState(); err != nil { if err := manager.updateSettingsState(); err != nil {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, err.Error())
return return
@@ -138,10 +150,10 @@ func handleGetColorScheme(conn net.Conn, req models.Request, manager *Manager) {
models.Respond(conn, req.ID, map[string]uint32{"colorScheme": state.Settings.ColorScheme}) models.Respond(conn, req.ID, map[string]uint32{"colorScheme": state.Settings.ColorScheme})
} }
func handleSetIconTheme(conn net.Conn, req models.Request, manager *Manager) { func handleSetIconTheme(conn net.Conn, req Request, manager *Manager) {
iconTheme, err := params.String(req.Params, "iconTheme") iconTheme, ok := req.Params["iconTheme"].(string)
if err != nil { if !ok {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, "missing or invalid 'iconTheme' parameter")
return return
} }
@@ -150,5 +162,5 @@ func handleSetIconTheme(conn net.Conn, req models.Request, manager *Manager) {
return return
} }
models.Respond(conn, req.ID, models.SuccessResult{Success: true, Message: "icon theme set"}) models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "icon theme set"})
} }

View File

@@ -74,10 +74,10 @@ func TestRespondError_Freedesktop(t *testing.T) {
func TestRespond_Freedesktop(t *testing.T) { func TestRespond_Freedesktop(t *testing.T) {
conn := newMockNetConn() conn := newMockNetConn()
result := models.SuccessResult{Success: true, Message: "test"} result := SuccessResult{Success: true, Message: "test"}
models.Respond(conn, 123, result) models.Respond(conn, 123, result)
var resp models.Response[models.SuccessResult] var resp models.Response[SuccessResult]
err := json.NewDecoder(conn.writeBuf).Decode(&resp) err := json.NewDecoder(conn.writeBuf).Decode(&resp)
require.NoError(t, err) require.NoError(t, err)
@@ -106,7 +106,7 @@ func TestHandleGetState(t *testing.T) {
} }
conn := newMockNetConn() conn := newMockNetConn()
req := models.Request{ID: 123, Method: "freedesktop.getState"} req := Request{ID: 123, Method: "freedesktop.getState"}
handleGetState(conn, req, manager) handleGetState(conn, req, manager)
@@ -131,7 +131,7 @@ func TestHandleSetIconFile(t *testing.T) {
} }
conn := newMockNetConn() conn := newMockNetConn()
req := models.Request{ req := Request{
ID: 123, ID: 123,
Method: "freedesktop.accounts.setIconFile", Method: "freedesktop.accounts.setIconFile",
Params: map[string]any{}, Params: map[string]any{},
@@ -164,7 +164,7 @@ func TestHandleSetIconFile(t *testing.T) {
} }
conn := newMockNetConn() conn := newMockNetConn()
req := models.Request{ req := Request{
ID: 123, ID: 123,
Method: "freedesktop.accounts.setIconFile", Method: "freedesktop.accounts.setIconFile",
Params: map[string]any{ Params: map[string]any{
@@ -174,7 +174,7 @@ func TestHandleSetIconFile(t *testing.T) {
handleSetIconFile(conn, req, manager) handleSetIconFile(conn, req, manager)
var resp models.Response[models.SuccessResult] var resp models.Response[SuccessResult]
err := json.NewDecoder(conn.writeBuf).Decode(&resp) err := json.NewDecoder(conn.writeBuf).Decode(&resp)
require.NoError(t, err) require.NoError(t, err)
@@ -196,7 +196,7 @@ func TestHandleSetIconFile(t *testing.T) {
} }
conn := newMockNetConn() conn := newMockNetConn()
req := models.Request{ req := Request{
ID: 123, ID: 123,
Method: "freedesktop.accounts.setIconFile", Method: "freedesktop.accounts.setIconFile",
Params: map[string]any{ Params: map[string]any{
@@ -206,7 +206,7 @@ func TestHandleSetIconFile(t *testing.T) {
handleSetIconFile(conn, req, manager) handleSetIconFile(conn, req, manager)
var resp models.Response[models.SuccessResult] var resp models.Response[SuccessResult]
err := json.NewDecoder(conn.writeBuf).Decode(&resp) err := json.NewDecoder(conn.writeBuf).Decode(&resp)
require.NoError(t, err) require.NoError(t, err)
@@ -223,7 +223,7 @@ func TestHandleSetRealName(t *testing.T) {
} }
conn := newMockNetConn() conn := newMockNetConn()
req := models.Request{ req := Request{
ID: 123, ID: 123,
Method: "freedesktop.accounts.setRealName", Method: "freedesktop.accounts.setRealName",
Params: map[string]any{}, Params: map[string]any{},
@@ -256,7 +256,7 @@ func TestHandleSetRealName(t *testing.T) {
} }
conn := newMockNetConn() conn := newMockNetConn()
req := models.Request{ req := Request{
ID: 123, ID: 123,
Method: "freedesktop.accounts.setRealName", Method: "freedesktop.accounts.setRealName",
Params: map[string]any{ Params: map[string]any{
@@ -266,7 +266,7 @@ func TestHandleSetRealName(t *testing.T) {
handleSetRealName(conn, req, manager) handleSetRealName(conn, req, manager)
var resp models.Response[models.SuccessResult] var resp models.Response[SuccessResult]
err := json.NewDecoder(conn.writeBuf).Decode(&resp) err := json.NewDecoder(conn.writeBuf).Decode(&resp)
require.NoError(t, err) require.NoError(t, err)
@@ -286,7 +286,7 @@ func TestHandleSetEmail(t *testing.T) {
} }
conn := newMockNetConn() conn := newMockNetConn()
req := models.Request{ req := Request{
ID: 123, ID: 123,
Method: "freedesktop.accounts.setEmail", Method: "freedesktop.accounts.setEmail",
Params: map[string]any{}, Params: map[string]any{},
@@ -319,7 +319,7 @@ func TestHandleSetEmail(t *testing.T) {
} }
conn := newMockNetConn() conn := newMockNetConn()
req := models.Request{ req := Request{
ID: 123, ID: 123,
Method: "freedesktop.accounts.setEmail", Method: "freedesktop.accounts.setEmail",
Params: map[string]any{ Params: map[string]any{
@@ -329,7 +329,7 @@ func TestHandleSetEmail(t *testing.T) {
handleSetEmail(conn, req, manager) handleSetEmail(conn, req, manager)
var resp models.Response[models.SuccessResult] var resp models.Response[SuccessResult]
err := json.NewDecoder(conn.writeBuf).Decode(&resp) err := json.NewDecoder(conn.writeBuf).Decode(&resp)
require.NoError(t, err) require.NoError(t, err)
@@ -349,7 +349,7 @@ func TestHandleSetLanguage(t *testing.T) {
} }
conn := newMockNetConn() conn := newMockNetConn()
req := models.Request{ req := Request{
ID: 123, ID: 123,
Method: "freedesktop.accounts.setLanguage", Method: "freedesktop.accounts.setLanguage",
Params: map[string]any{}, Params: map[string]any{},
@@ -374,7 +374,7 @@ func TestHandleSetLocation(t *testing.T) {
} }
conn := newMockNetConn() conn := newMockNetConn()
req := models.Request{ req := Request{
ID: 123, ID: 123,
Method: "freedesktop.accounts.setLocation", Method: "freedesktop.accounts.setLocation",
Params: map[string]any{}, Params: map[string]any{},
@@ -399,7 +399,7 @@ func TestHandleGetUserIconFile(t *testing.T) {
} }
conn := newMockNetConn() conn := newMockNetConn()
req := models.Request{ req := Request{
ID: 123, ID: 123,
Method: "freedesktop.accounts.getUserIconFile", Method: "freedesktop.accounts.getUserIconFile",
Params: map[string]any{}, Params: map[string]any{},
@@ -426,7 +426,7 @@ func TestHandleGetUserIconFile(t *testing.T) {
} }
conn := newMockNetConn() conn := newMockNetConn()
req := models.Request{ req := Request{
ID: 123, ID: 123,
Method: "freedesktop.accounts.getUserIconFile", Method: "freedesktop.accounts.getUserIconFile",
Params: map[string]any{ Params: map[string]any{
@@ -436,7 +436,7 @@ func TestHandleGetUserIconFile(t *testing.T) {
handleGetUserIconFile(conn, req, manager) handleGetUserIconFile(conn, req, manager)
var resp models.Response[models.SuccessResult] var resp models.Response[SuccessResult]
err := json.NewDecoder(conn.writeBuf).Decode(&resp) err := json.NewDecoder(conn.writeBuf).Decode(&resp)
require.NoError(t, err) require.NoError(t, err)
@@ -457,7 +457,7 @@ func TestHandleGetColorScheme(t *testing.T) {
} }
conn := newMockNetConn() conn := newMockNetConn()
req := models.Request{ID: 123, Method: "freedesktop.settings.getColorScheme"} req := Request{ID: 123, Method: "freedesktop.settings.getColorScheme"}
handleGetColorScheme(conn, req, manager) handleGetColorScheme(conn, req, manager)
@@ -488,7 +488,7 @@ func TestHandleGetColorScheme(t *testing.T) {
} }
conn := newMockNetConn() conn := newMockNetConn()
req := models.Request{ID: 123, Method: "freedesktop.settings.getColorScheme"} req := Request{ID: 123, Method: "freedesktop.settings.getColorScheme"}
handleGetColorScheme(conn, req, manager) handleGetColorScheme(conn, req, manager)
@@ -516,7 +516,7 @@ func TestHandleRequest(t *testing.T) {
t.Run("unknown method", func(t *testing.T) { t.Run("unknown method", func(t *testing.T) {
conn := newMockNetConn() conn := newMockNetConn()
req := models.Request{ req := Request{
ID: 123, ID: 123,
Method: "freedesktop.unknown", Method: "freedesktop.unknown",
} }
@@ -533,7 +533,7 @@ func TestHandleRequest(t *testing.T) {
t.Run("valid method - getState", func(t *testing.T) { t.Run("valid method - getState", func(t *testing.T) {
conn := newMockNetConn() conn := newMockNetConn()
req := models.Request{ req := Request{
ID: 123, ID: 123,
Method: "freedesktop.getState", Method: "freedesktop.getState",
} }
@@ -561,7 +561,7 @@ func TestHandleRequest(t *testing.T) {
for _, method := range tests { for _, method := range tests {
conn := newMockNetConn() conn := newMockNetConn()
req := models.Request{ req := Request{
ID: 123, ID: 123,
Method: method, Method: method,
Params: map[string]any{}, Params: map[string]any{},

View File

@@ -6,10 +6,20 @@ import (
"net" "net"
"github.com/AvengeMedia/DankMaterialShell/core/internal/server/models" "github.com/AvengeMedia/DankMaterialShell/core/internal/server/models"
"github.com/AvengeMedia/DankMaterialShell/core/internal/server/params"
) )
func HandleRequest(conn net.Conn, req models.Request, manager *Manager) { type Request struct {
ID int `json:"id,omitempty"`
Method string `json:"method"`
Params map[string]any `json:"params,omitempty"`
}
type SuccessResult struct {
Success bool `json:"success"`
Message string `json:"message"`
}
func HandleRequest(conn net.Conn, req Request, manager *Manager) {
switch req.Method { switch req.Method {
case "loginctl.getState": case "loginctl.getState":
handleGetState(conn, req, manager) handleGetState(conn, req, manager)
@@ -36,38 +46,39 @@ func HandleRequest(conn net.Conn, req models.Request, manager *Manager) {
} }
} }
func handleGetState(conn net.Conn, req models.Request, manager *Manager) { func handleGetState(conn net.Conn, req Request, manager *Manager) {
models.Respond(conn, req.ID, manager.GetState()) state := manager.GetState()
models.Respond(conn, req.ID, state)
} }
func handleLock(conn net.Conn, req models.Request, manager *Manager) { func handleLock(conn net.Conn, req Request, manager *Manager) {
if err := manager.Lock(); err != nil { if err := manager.Lock(); err != nil {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, err.Error())
return return
} }
models.Respond(conn, req.ID, models.SuccessResult{Success: true, Message: "locked"}) models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "locked"})
} }
func handleUnlock(conn net.Conn, req models.Request, manager *Manager) { func handleUnlock(conn net.Conn, req Request, manager *Manager) {
if err := manager.Unlock(); err != nil { if err := manager.Unlock(); err != nil {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, err.Error())
return return
} }
models.Respond(conn, req.ID, models.SuccessResult{Success: true, Message: "unlocked"}) models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "unlocked"})
} }
func handleActivate(conn net.Conn, req models.Request, manager *Manager) { func handleActivate(conn net.Conn, req Request, manager *Manager) {
if err := manager.Activate(); err != nil { if err := manager.Activate(); err != nil {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, err.Error())
return return
} }
models.Respond(conn, req.ID, models.SuccessResult{Success: true, Message: "activated"}) models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "activated"})
} }
func handleSetIdleHint(conn net.Conn, req models.Request, manager *Manager) { func handleSetIdleHint(conn net.Conn, req Request, manager *Manager) {
idle, err := params.Bool(req.Params, "idle") idle, ok := req.Params["idle"].(bool)
if err != nil { if !ok {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, "missing or invalid 'idle' parameter")
return return
} }
@@ -75,32 +86,32 @@ func handleSetIdleHint(conn net.Conn, req models.Request, manager *Manager) {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, err.Error())
return return
} }
models.Respond(conn, req.ID, models.SuccessResult{Success: true, Message: "idle hint set"}) models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "idle hint set"})
} }
func handleSetLockBeforeSuspend(conn net.Conn, req models.Request, manager *Manager) { func handleSetLockBeforeSuspend(conn net.Conn, req Request, manager *Manager) {
enabled, err := params.Bool(req.Params, "enabled") enabled, ok := req.Params["enabled"].(bool)
if err != nil { if !ok {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, "missing or invalid 'enabled' parameter")
return return
} }
manager.SetLockBeforeSuspend(enabled) manager.SetLockBeforeSuspend(enabled)
models.Respond(conn, req.ID, models.SuccessResult{Success: true, Message: "lock before suspend set"}) models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "lock before suspend set"})
} }
func handleSetSleepInhibitorEnabled(conn net.Conn, req models.Request, manager *Manager) { func handleSetSleepInhibitorEnabled(conn net.Conn, req Request, manager *Manager) {
enabled, err := params.Bool(req.Params, "enabled") enabled, ok := req.Params["enabled"].(bool)
if err != nil { if !ok {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, "missing or invalid 'enabled' parameter")
return return
} }
manager.SetSleepInhibitorEnabled(enabled) manager.SetSleepInhibitorEnabled(enabled)
models.Respond(conn, req.ID, models.SuccessResult{Success: true, Message: "sleep inhibitor setting updated"}) models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "sleep inhibitor setting updated"})
} }
func handleLockerReady(conn net.Conn, req models.Request, manager *Manager) { func handleLockerReady(conn net.Conn, req Request, manager *Manager) {
manager.lockTimerMu.Lock() manager.lockTimerMu.Lock()
if manager.lockTimer != nil { if manager.lockTimer != nil {
manager.lockTimer.Stop() manager.lockTimer.Stop()
@@ -114,18 +125,18 @@ func handleLockerReady(conn net.Conn, req models.Request, manager *Manager) {
if manager.inSleepCycle.Load() { if manager.inSleepCycle.Load() {
manager.signalLockerReady() manager.signalLockerReady()
} }
models.Respond(conn, req.ID, models.SuccessResult{Success: true, Message: "ok"}) models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "ok"})
} }
func handleTerminate(conn net.Conn, req models.Request, manager *Manager) { func handleTerminate(conn net.Conn, req Request, manager *Manager) {
if err := manager.Terminate(); err != nil { if err := manager.Terminate(); err != nil {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, err.Error())
return return
} }
models.Respond(conn, req.ID, models.SuccessResult{Success: true, Message: "terminated"}) models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "terminated"})
} }
func handleSubscribe(conn net.Conn, req models.Request, manager *Manager) { func handleSubscribe(conn net.Conn, req Request, manager *Manager) {
clientID := fmt.Sprintf("client-%p", conn) clientID := fmt.Sprintf("client-%p", conn)
stateChan := manager.Subscribe(clientID) stateChan := manager.Subscribe(clientID)
defer manager.Unsubscribe(clientID) defer manager.Unsubscribe(clientID)

View File

@@ -58,10 +58,10 @@ func TestRespondError_Loginctl(t *testing.T) {
func TestRespond_Loginctl(t *testing.T) { func TestRespond_Loginctl(t *testing.T) {
conn := newMockNetConn() conn := newMockNetConn()
result := models.SuccessResult{Success: true, Message: "test"} result := SuccessResult{Success: true, Message: "test"}
models.Respond(conn, 123, result) models.Respond(conn, 123, result)
var resp models.Response[models.SuccessResult] var resp models.Response[SuccessResult]
err := json.NewDecoder(conn.writeBuf).Decode(&resp) err := json.NewDecoder(conn.writeBuf).Decode(&resp)
require.NoError(t, err) require.NoError(t, err)
@@ -86,7 +86,7 @@ func TestHandleGetState(t *testing.T) {
} }
conn := newMockNetConn() conn := newMockNetConn()
req := models.Request{ID: 123, Method: "loginctl.getState"} req := Request{ID: 123, Method: "loginctl.getState"}
handleGetState(conn, req, manager) handleGetState(conn, req, manager)
@@ -115,10 +115,10 @@ func TestHandleLock(t *testing.T) {
} }
conn := newMockNetConn() conn := newMockNetConn()
req := models.Request{ID: 123, Method: "loginctl.lock"} req := Request{ID: 123, Method: "loginctl.lock"}
handleLock(conn, req, manager) handleLock(conn, req, manager)
var resp models.Response[models.SuccessResult] var resp models.Response[SuccessResult]
err := json.NewDecoder(conn.writeBuf).Decode(&resp) err := json.NewDecoder(conn.writeBuf).Decode(&resp)
require.NoError(t, err) require.NoError(t, err)
@@ -141,10 +141,10 @@ func TestHandleLock(t *testing.T) {
} }
conn := newMockNetConn() conn := newMockNetConn()
req := models.Request{ID: 123, Method: "loginctl.lock"} req := Request{ID: 123, Method: "loginctl.lock"}
handleLock(conn, req, manager) handleLock(conn, req, manager)
var resp models.Response[models.SuccessResult] var resp models.Response[SuccessResult]
err := json.NewDecoder(conn.writeBuf).Decode(&resp) err := json.NewDecoder(conn.writeBuf).Decode(&resp)
require.NoError(t, err) require.NoError(t, err)
@@ -166,10 +166,10 @@ func TestHandleUnlock(t *testing.T) {
} }
conn := newMockNetConn() conn := newMockNetConn()
req := models.Request{ID: 123, Method: "loginctl.unlock"} req := Request{ID: 123, Method: "loginctl.unlock"}
handleUnlock(conn, req, manager) handleUnlock(conn, req, manager)
var resp models.Response[models.SuccessResult] var resp models.Response[SuccessResult]
err := json.NewDecoder(conn.writeBuf).Decode(&resp) err := json.NewDecoder(conn.writeBuf).Decode(&resp)
require.NoError(t, err) require.NoError(t, err)
@@ -192,10 +192,10 @@ func TestHandleUnlock(t *testing.T) {
} }
conn := newMockNetConn() conn := newMockNetConn()
req := models.Request{ID: 123, Method: "loginctl.unlock"} req := Request{ID: 123, Method: "loginctl.unlock"}
handleUnlock(conn, req, manager) handleUnlock(conn, req, manager)
var resp models.Response[models.SuccessResult] var resp models.Response[SuccessResult]
err := json.NewDecoder(conn.writeBuf).Decode(&resp) err := json.NewDecoder(conn.writeBuf).Decode(&resp)
require.NoError(t, err) require.NoError(t, err)
@@ -217,10 +217,10 @@ func TestHandleActivate(t *testing.T) {
} }
conn := newMockNetConn() conn := newMockNetConn()
req := models.Request{ID: 123, Method: "loginctl.activate"} req := Request{ID: 123, Method: "loginctl.activate"}
handleActivate(conn, req, manager) handleActivate(conn, req, manager)
var resp models.Response[models.SuccessResult] var resp models.Response[SuccessResult]
err := json.NewDecoder(conn.writeBuf).Decode(&resp) err := json.NewDecoder(conn.writeBuf).Decode(&resp)
require.NoError(t, err) require.NoError(t, err)
@@ -243,10 +243,10 @@ func TestHandleActivate(t *testing.T) {
} }
conn := newMockNetConn() conn := newMockNetConn()
req := models.Request{ID: 123, Method: "loginctl.activate"} req := Request{ID: 123, Method: "loginctl.activate"}
handleActivate(conn, req, manager) handleActivate(conn, req, manager)
var resp models.Response[models.SuccessResult] var resp models.Response[SuccessResult]
err := json.NewDecoder(conn.writeBuf).Decode(&resp) err := json.NewDecoder(conn.writeBuf).Decode(&resp)
require.NoError(t, err) require.NoError(t, err)
@@ -263,7 +263,7 @@ func TestHandleSetIdleHint(t *testing.T) {
} }
conn := newMockNetConn() conn := newMockNetConn()
req := models.Request{ req := Request{
ID: 123, ID: 123,
Method: "loginctl.setIdleHint", Method: "loginctl.setIdleHint",
Params: map[string]any{}, Params: map[string]any{},
@@ -291,7 +291,7 @@ func TestHandleSetIdleHint(t *testing.T) {
} }
conn := newMockNetConn() conn := newMockNetConn()
req := models.Request{ req := Request{
ID: 123, ID: 123,
Method: "loginctl.setIdleHint", Method: "loginctl.setIdleHint",
Params: map[string]any{ Params: map[string]any{
@@ -301,7 +301,7 @@ func TestHandleSetIdleHint(t *testing.T) {
handleSetIdleHint(conn, req, manager) handleSetIdleHint(conn, req, manager)
var resp models.Response[models.SuccessResult] var resp models.Response[SuccessResult]
err := json.NewDecoder(conn.writeBuf).Decode(&resp) err := json.NewDecoder(conn.writeBuf).Decode(&resp)
require.NoError(t, err) require.NoError(t, err)
@@ -324,7 +324,7 @@ func TestHandleSetIdleHint(t *testing.T) {
} }
conn := newMockNetConn() conn := newMockNetConn()
req := models.Request{ req := Request{
ID: 123, ID: 123,
Method: "loginctl.setIdleHint", Method: "loginctl.setIdleHint",
Params: map[string]any{ Params: map[string]any{
@@ -334,7 +334,7 @@ func TestHandleSetIdleHint(t *testing.T) {
handleSetIdleHint(conn, req, manager) handleSetIdleHint(conn, req, manager)
var resp models.Response[models.SuccessResult] var resp models.Response[SuccessResult]
err := json.NewDecoder(conn.writeBuf).Decode(&resp) err := json.NewDecoder(conn.writeBuf).Decode(&resp)
require.NoError(t, err) require.NoError(t, err)
@@ -356,10 +356,10 @@ func TestHandleTerminate(t *testing.T) {
} }
conn := newMockNetConn() conn := newMockNetConn()
req := models.Request{ID: 123, Method: "loginctl.terminate"} req := Request{ID: 123, Method: "loginctl.terminate"}
handleTerminate(conn, req, manager) handleTerminate(conn, req, manager)
var resp models.Response[models.SuccessResult] var resp models.Response[SuccessResult]
err := json.NewDecoder(conn.writeBuf).Decode(&resp) err := json.NewDecoder(conn.writeBuf).Decode(&resp)
require.NoError(t, err) require.NoError(t, err)
@@ -382,10 +382,10 @@ func TestHandleTerminate(t *testing.T) {
} }
conn := newMockNetConn() conn := newMockNetConn()
req := models.Request{ID: 123, Method: "loginctl.terminate"} req := Request{ID: 123, Method: "loginctl.terminate"}
handleTerminate(conn, req, manager) handleTerminate(conn, req, manager)
var resp models.Response[models.SuccessResult] var resp models.Response[SuccessResult]
err := json.NewDecoder(conn.writeBuf).Decode(&resp) err := json.NewDecoder(conn.writeBuf).Decode(&resp)
require.NoError(t, err) require.NoError(t, err)
@@ -405,7 +405,7 @@ func TestHandleRequest(t *testing.T) {
t.Run("unknown method", func(t *testing.T) { t.Run("unknown method", func(t *testing.T) {
conn := newMockNetConn() conn := newMockNetConn()
req := models.Request{ req := Request{
ID: 123, ID: 123,
Method: "loginctl.unknown", Method: "loginctl.unknown",
} }
@@ -422,7 +422,7 @@ func TestHandleRequest(t *testing.T) {
t.Run("valid method - getState", func(t *testing.T) { t.Run("valid method - getState", func(t *testing.T) {
conn := newMockNetConn() conn := newMockNetConn()
req := models.Request{ req := Request{
ID: 123, ID: 123,
Method: "loginctl.getState", Method: "loginctl.getState",
} }
@@ -445,7 +445,7 @@ func TestHandleRequest(t *testing.T) {
manager.sessionObj = mockSessionObj manager.sessionObj = mockSessionObj
conn := newMockNetConn() conn := newMockNetConn()
req := models.Request{ req := Request{
ID: 123, ID: 123,
Method: "loginctl.lock", Method: "loginctl.lock",
} }
@@ -470,7 +470,7 @@ func TestHandleSubscribe(t *testing.T) {
} }
conn := newMockNetConn() conn := newMockNetConn()
req := models.Request{ID: 123, Method: "loginctl.subscribe"} req := Request{ID: 123, Method: "loginctl.subscribe"}
done := make(chan bool) done := make(chan bool)
go func() { go func() {

View File

@@ -29,9 +29,3 @@ func Respond[T any](conn net.Conn, id int, result T) {
resp := Response[T]{ID: id, Result: &result} resp := Response[T]{ID: id, Result: &result}
json.NewEncoder(conn).Encode(resp) json.NewEncoder(conn).Encode(resp)
} }
type SuccessResult struct {
Success bool `json:"success"`
Message string `json:"message,omitempty"`
Value string `json:"value,omitempty"`
}

View File

@@ -7,10 +7,20 @@ import (
"github.com/AvengeMedia/DankMaterialShell/core/internal/log" "github.com/AvengeMedia/DankMaterialShell/core/internal/log"
"github.com/AvengeMedia/DankMaterialShell/core/internal/server/models" "github.com/AvengeMedia/DankMaterialShell/core/internal/server/models"
"github.com/AvengeMedia/DankMaterialShell/core/internal/server/params"
) )
func HandleRequest(conn net.Conn, req models.Request, manager *Manager) { type Request struct {
ID int `json:"id,omitempty"`
Method string `json:"method"`
Params map[string]any `json:"params,omitempty"`
}
type SuccessResult struct {
Success bool `json:"success"`
Message string `json:"message"`
}
func HandleRequest(conn net.Conn, req Request, manager *Manager) {
switch req.Method { switch req.Method {
case "network.getState": case "network.getState":
handleGetState(conn, req, manager) handleGetState(conn, req, manager)
@@ -79,22 +89,32 @@ func HandleRequest(conn net.Conn, req models.Request, manager *Manager) {
} }
} }
func handleCredentialsSubmit(conn net.Conn, req models.Request, manager *Manager) { func handleCredentialsSubmit(conn net.Conn, req Request, manager *Manager) {
token, err := params.String(req.Params, "token") token, ok := req.Params["token"].(string)
if err != nil { if !ok {
log.Warnf("handleCredentialsSubmit: missing or invalid token parameter") log.Warnf("handleCredentialsSubmit: missing or invalid token parameter")
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, "missing or invalid 'token' parameter")
return return
} }
secrets, err := params.StringMap(req.Params, "secrets") secretsRaw, ok := req.Params["secrets"].(map[string]any)
if err != nil { if !ok {
log.Warnf("handleCredentialsSubmit: missing or invalid secrets parameter") log.Warnf("handleCredentialsSubmit: missing or invalid secrets parameter")
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, "missing or invalid 'secrets' parameter")
return return
} }
save := params.BoolOpt(req.Params, "save", true) secrets := make(map[string]string)
for k, v := range secretsRaw {
if str, ok := v.(string); ok {
secrets[k] = str
}
}
save := true
if saveParam, ok := req.Params["save"].(bool); ok {
save = saveParam
}
if err := manager.SubmitCredentials(token, secrets, save); err != nil { if err := manager.SubmitCredentials(token, secrets, save); err != nil {
log.Warnf("handleCredentialsSubmit: failed to submit credentials: %v", err) log.Warnf("handleCredentialsSubmit: failed to submit credentials: %v", err)
@@ -103,13 +123,13 @@ func handleCredentialsSubmit(conn net.Conn, req models.Request, manager *Manager
} }
log.Infof("handleCredentialsSubmit: credentials submitted successfully") log.Infof("handleCredentialsSubmit: credentials submitted successfully")
models.Respond(conn, req.ID, models.SuccessResult{Success: true, Message: "credentials submitted"}) models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "credentials submitted"})
} }
func handleCredentialsCancel(conn net.Conn, req models.Request, manager *Manager) { func handleCredentialsCancel(conn net.Conn, req Request, manager *Manager) {
token, err := params.String(req.Params, "token") token, ok := req.Params["token"].(string)
if err != nil { if !ok {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, "missing or invalid 'token' parameter")
return return
} }
@@ -118,15 +138,16 @@ func handleCredentialsCancel(conn net.Conn, req models.Request, manager *Manager
return return
} }
models.Respond(conn, req.ID, models.SuccessResult{Success: true, Message: "credentials cancelled"}) models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "credentials cancelled"})
} }
func handleGetState(conn net.Conn, req models.Request, manager *Manager) { func handleGetState(conn net.Conn, req Request, manager *Manager) {
models.Respond(conn, req.ID, manager.GetState()) state := manager.GetState()
models.Respond(conn, req.ID, state)
} }
func handleScanWiFi(conn net.Conn, req models.Request, manager *Manager) { func handleScanWiFi(conn net.Conn, req Request, manager *Manager) {
device := params.StringOpt(req.Params, "device", "") device, _ := req.Params["device"].(string)
var err error var err error
if device != "" { if device != "" {
err = manager.ScanWiFiDevice(device) err = manager.ScanWiFiDevice(device)
@@ -137,25 +158,33 @@ func handleScanWiFi(conn net.Conn, req models.Request, manager *Manager) {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, err.Error())
return return
} }
models.Respond(conn, req.ID, models.SuccessResult{Success: true, Message: "scanning"}) models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "scanning"})
} }
func handleGetWiFiNetworks(conn net.Conn, req models.Request, manager *Manager) { func handleGetWiFiNetworks(conn net.Conn, req Request, manager *Manager) {
models.Respond(conn, req.ID, manager.GetWiFiNetworks()) networks := manager.GetWiFiNetworks()
models.Respond(conn, req.ID, networks)
} }
func handleConnectWiFi(conn net.Conn, req models.Request, manager *Manager) { func handleConnectWiFi(conn net.Conn, req Request, manager *Manager) {
ssid, err := params.String(req.Params, "ssid") ssid, ok := req.Params["ssid"].(string)
if err != nil { if !ok {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, "missing or invalid 'ssid' parameter")
return return
} }
var connReq ConnectionRequest var connReq ConnectionRequest
connReq.SSID = ssid connReq.SSID = ssid
connReq.Password = params.StringOpt(req.Params, "password", "")
connReq.Username = params.StringOpt(req.Params, "username", "") if password, ok := req.Params["password"].(string); ok {
connReq.Device = params.StringOpt(req.Params, "device", "") connReq.Password = password
}
if username, ok := req.Params["username"].(string); ok {
connReq.Username = username
}
if device, ok := req.Params["device"].(string); ok {
connReq.Device = device
}
if interactive, ok := req.Params["interactive"].(bool); ok { if interactive, ok := req.Params["interactive"].(bool); ok {
connReq.Interactive = interactive connReq.Interactive = interactive
@@ -177,14 +206,27 @@ func handleConnectWiFi(conn net.Conn, req models.Request, manager *Manager) {
} }
} }
connReq.AnonymousIdentity = params.StringOpt(req.Params, "anonymousIdentity", "") if anonymousIdentity, ok := req.Params["anonymousIdentity"].(string); ok {
connReq.DomainSuffixMatch = params.StringOpt(req.Params, "domainSuffixMatch", "") connReq.AnonymousIdentity = anonymousIdentity
connReq.EAPMethod = params.StringOpt(req.Params, "eapMethod", "") }
connReq.Phase2Auth = params.StringOpt(req.Params, "phase2Auth", "") if domainSuffixMatch, ok := req.Params["domainSuffixMatch"].(string); ok {
connReq.CACertPath = params.StringOpt(req.Params, "caCertPath", "") connReq.DomainSuffixMatch = domainSuffixMatch
connReq.ClientCertPath = params.StringOpt(req.Params, "clientCertPath", "") }
connReq.PrivateKeyPath = params.StringOpt(req.Params, "privateKeyPath", "") if eapMethod, ok := req.Params["eapMethod"].(string); ok {
connReq.EAPMethod = eapMethod
}
if phase2Auth, ok := req.Params["phase2Auth"].(string); ok {
connReq.Phase2Auth = phase2Auth
}
if caCertPath, ok := req.Params["caCertPath"].(string); ok {
connReq.CACertPath = caCertPath
}
if clientCertPath, ok := req.Params["clientCertPath"].(string); ok {
connReq.ClientCertPath = clientCertPath
}
if privateKeyPath, ok := req.Params["privateKeyPath"].(string); ok {
connReq.PrivateKeyPath = privateKeyPath
}
if useSystemCACerts, ok := req.Params["useSystemCACerts"].(bool); ok { if useSystemCACerts, ok := req.Params["useSystemCACerts"].(bool); ok {
connReq.UseSystemCACerts = &useSystemCACerts connReq.UseSystemCACerts = &useSystemCACerts
} }
@@ -194,11 +236,11 @@ func handleConnectWiFi(conn net.Conn, req models.Request, manager *Manager) {
return return
} }
models.Respond(conn, req.ID, models.SuccessResult{Success: true, Message: "connecting"}) models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "connecting"})
} }
func handleDisconnectWiFi(conn net.Conn, req models.Request, manager *Manager) { func handleDisconnectWiFi(conn net.Conn, req Request, manager *Manager) {
device := params.StringOpt(req.Params, "device", "") device, _ := req.Params["device"].(string)
var err error var err error
if device != "" { if device != "" {
err = manager.DisconnectWiFiDevice(device) err = manager.DisconnectWiFiDevice(device)
@@ -209,13 +251,13 @@ func handleDisconnectWiFi(conn net.Conn, req models.Request, manager *Manager) {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, err.Error())
return return
} }
models.Respond(conn, req.ID, models.SuccessResult{Success: true, Message: "disconnected"}) models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "disconnected"})
} }
func handleForgetWiFi(conn net.Conn, req models.Request, manager *Manager) { func handleForgetWiFi(conn net.Conn, req Request, manager *Manager) {
ssid, err := params.String(req.Params, "ssid") ssid, ok := req.Params["ssid"].(string)
if err != nil { if !ok {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, "missing or invalid 'ssid' parameter")
return return
} }
@@ -224,10 +266,10 @@ func handleForgetWiFi(conn net.Conn, req models.Request, manager *Manager) {
return return
} }
models.Respond(conn, req.ID, models.SuccessResult{Success: true, Message: "forgotten"}) models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "forgotten"})
} }
func handleToggleWiFi(conn net.Conn, req models.Request, manager *Manager) { func handleToggleWiFi(conn net.Conn, req Request, manager *Manager) {
if err := manager.ToggleWiFi(); err != nil { if err := manager.ToggleWiFi(); err != nil {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, err.Error())
return return
@@ -237,7 +279,7 @@ func handleToggleWiFi(conn net.Conn, req models.Request, manager *Manager) {
models.Respond(conn, req.ID, map[string]bool{"enabled": state.WiFiEnabled}) models.Respond(conn, req.ID, map[string]bool{"enabled": state.WiFiEnabled})
} }
func handleEnableWiFi(conn net.Conn, req models.Request, manager *Manager) { func handleEnableWiFi(conn net.Conn, req Request, manager *Manager) {
if err := manager.EnableWiFi(); err != nil { if err := manager.EnableWiFi(); err != nil {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, err.Error())
return return
@@ -245,7 +287,7 @@ func handleEnableWiFi(conn net.Conn, req models.Request, manager *Manager) {
models.Respond(conn, req.ID, map[string]bool{"enabled": true}) models.Respond(conn, req.ID, map[string]bool{"enabled": true})
} }
func handleDisableWiFi(conn net.Conn, req models.Request, manager *Manager) { func handleDisableWiFi(conn net.Conn, req Request, manager *Manager) {
if err := manager.DisableWiFi(); err != nil { if err := manager.DisableWiFi(); err != nil {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, err.Error())
return return
@@ -253,29 +295,29 @@ func handleDisableWiFi(conn net.Conn, req models.Request, manager *Manager) {
models.Respond(conn, req.ID, map[string]bool{"enabled": false}) models.Respond(conn, req.ID, map[string]bool{"enabled": false})
} }
func handleConnectEthernetSpecificConfig(conn net.Conn, req models.Request, manager *Manager) { func handleConnectEthernetSpecificConfig(conn net.Conn, req Request, manager *Manager) {
uuid, err := params.String(req.Params, "uuid") uuid, ok := req.Params["uuid"].(string)
if err != nil { if !ok {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, "missing or invalid 'uuid' parameter")
return return
} }
if err := manager.activateConnection(uuid); err != nil { if err := manager.activateConnection(uuid); err != nil {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, err.Error())
return return
} }
models.Respond(conn, req.ID, models.SuccessResult{Success: true, Message: "connecting"}) models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "connecting"})
} }
func handleConnectEthernet(conn net.Conn, req models.Request, manager *Manager) { func handleConnectEthernet(conn net.Conn, req Request, manager *Manager) {
if err := manager.ConnectEthernet(); err != nil { if err := manager.ConnectEthernet(); err != nil {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, err.Error())
return return
} }
models.Respond(conn, req.ID, models.SuccessResult{Success: true, Message: "connecting"}) models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "connecting"})
} }
func handleDisconnectEthernet(conn net.Conn, req models.Request, manager *Manager) { func handleDisconnectEthernet(conn net.Conn, req Request, manager *Manager) {
device := params.StringOpt(req.Params, "device", "") device, _ := req.Params["device"].(string)
var err error var err error
if device != "" { if device != "" {
err = manager.DisconnectEthernetDevice(device) err = manager.DisconnectEthernetDevice(device)
@@ -286,13 +328,13 @@ func handleDisconnectEthernet(conn net.Conn, req models.Request, manager *Manage
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, err.Error())
return return
} }
models.Respond(conn, req.ID, models.SuccessResult{Success: true, Message: "disconnected"}) models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "disconnected"})
} }
func handleSetPreference(conn net.Conn, req models.Request, manager *Manager) { func handleSetPreference(conn net.Conn, req Request, manager *Manager) {
preference, err := params.String(req.Params, "preference") preference, ok := req.Params["preference"].(string)
if err != nil { if !ok {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, "missing or invalid 'preference' parameter")
return return
} }
@@ -304,10 +346,10 @@ func handleSetPreference(conn net.Conn, req models.Request, manager *Manager) {
models.Respond(conn, req.ID, map[string]string{"preference": preference}) models.Respond(conn, req.ID, map[string]string{"preference": preference})
} }
func handleGetNetworkInfo(conn net.Conn, req models.Request, manager *Manager) { func handleGetNetworkInfo(conn net.Conn, req Request, manager *Manager) {
ssid, err := params.String(req.Params, "ssid") ssid, ok := req.Params["ssid"].(string)
if err != nil { if !ok {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, "missing or invalid 'ssid' parameter")
return return
} }
@@ -320,10 +362,10 @@ func handleGetNetworkInfo(conn net.Conn, req models.Request, manager *Manager) {
models.Respond(conn, req.ID, network) models.Respond(conn, req.ID, network)
} }
func handleGetWiredNetworkInfo(conn net.Conn, req models.Request, manager *Manager) { func handleGetWiredNetworkInfo(conn net.Conn, req Request, manager *Manager) {
uuid, err := params.String(req.Params, "uuid") uuid, ok := req.Params["uuid"].(string)
if err != nil { if !ok {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, "missing or invalid 'uuid' parameter")
return return
} }
@@ -336,7 +378,7 @@ func handleGetWiredNetworkInfo(conn net.Conn, req models.Request, manager *Manag
models.Respond(conn, req.ID, network) models.Respond(conn, req.ID, network)
} }
func handleSubscribe(conn net.Conn, req models.Request, manager *Manager) { func handleSubscribe(conn net.Conn, req Request, manager *Manager) {
clientID := fmt.Sprintf("client-%p", conn) clientID := fmt.Sprintf("client-%p", conn)
stateChan := manager.Subscribe(clientID) stateChan := manager.Subscribe(clientID)
defer manager.Unsubscribe(clientID) defer manager.Unsubscribe(clientID)
@@ -366,7 +408,7 @@ func handleSubscribe(conn net.Conn, req models.Request, manager *Manager) {
} }
} }
func handleListVPNProfiles(conn net.Conn, req models.Request, manager *Manager) { func handleListVPNProfiles(conn net.Conn, req Request, manager *Manager) {
profiles, err := manager.ListVPNProfiles() profiles, err := manager.ListVPNProfiles()
if err != nil { if err != nil {
log.Warnf("handleListVPNProfiles: failed to list profiles: %v", err) log.Warnf("handleListVPNProfiles: failed to list profiles: %v", err)
@@ -377,7 +419,7 @@ func handleListVPNProfiles(conn net.Conn, req models.Request, manager *Manager)
models.Respond(conn, req.ID, profiles) models.Respond(conn, req.ID, profiles)
} }
func handleListActiveVPN(conn net.Conn, req models.Request, manager *Manager) { func handleListActiveVPN(conn net.Conn, req Request, manager *Manager) {
active, err := manager.ListActiveVPN() active, err := manager.ListActiveVPN()
if err != nil { if err != nil {
log.Warnf("handleListActiveVPN: failed to list active VPNs: %v", err) log.Warnf("handleListActiveVPN: failed to list active VPNs: %v", err)
@@ -388,15 +430,27 @@ func handleListActiveVPN(conn net.Conn, req models.Request, manager *Manager) {
models.Respond(conn, req.ID, active) models.Respond(conn, req.ID, active)
} }
func handleConnectVPN(conn net.Conn, req models.Request, manager *Manager) { func handleConnectVPN(conn net.Conn, req Request, manager *Manager) {
uuidOrName, ok := params.StringAlt(req.Params, "uuidOrName", "name", "uuid") uuidOrName, ok := req.Params["uuidOrName"].(string)
if !ok { if !ok {
log.Warnf("handleConnectVPN: missing uuidOrName/name/uuid parameter") name, nameOk := req.Params["name"].(string)
models.RespondError(conn, req.ID, "missing 'uuidOrName', 'name', or 'uuid' parameter") uuid, uuidOk := req.Params["uuid"].(string)
return if nameOk {
uuidOrName = name
} else if uuidOk {
uuidOrName = uuid
} else {
log.Warnf("handleConnectVPN: missing uuidOrName/name/uuid parameter")
models.RespondError(conn, req.ID, "missing 'uuidOrName', 'name', or 'uuid' parameter")
return
}
} }
singleActive := params.BoolOpt(req.Params, "singleActive", true) // Default to true - only allow one VPN connection at a time
singleActive := true
if sa, ok := req.Params["singleActive"].(bool); ok {
singleActive = sa
}
if err := manager.ConnectVPN(uuidOrName, singleActive); err != nil { if err := manager.ConnectVPN(uuidOrName, singleActive); err != nil {
log.Warnf("handleConnectVPN: failed to connect: %v", err) log.Warnf("handleConnectVPN: failed to connect: %v", err)
@@ -404,15 +458,23 @@ func handleConnectVPN(conn net.Conn, req models.Request, manager *Manager) {
return return
} }
models.Respond(conn, req.ID, models.SuccessResult{Success: true, Message: "VPN connection initiated"}) models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "VPN connection initiated"})
} }
func handleDisconnectVPN(conn net.Conn, req models.Request, manager *Manager) { func handleDisconnectVPN(conn net.Conn, req Request, manager *Manager) {
uuidOrName, ok := params.StringAlt(req.Params, "uuidOrName", "name", "uuid") uuidOrName, ok := req.Params["uuidOrName"].(string)
if !ok { if !ok {
log.Warnf("handleDisconnectVPN: missing uuidOrName/name/uuid parameter") name, nameOk := req.Params["name"].(string)
models.RespondError(conn, req.ID, "missing 'uuidOrName', 'name', or 'uuid' parameter") uuid, uuidOk := req.Params["uuid"].(string)
return if nameOk {
uuidOrName = name
} else if uuidOk {
uuidOrName = uuid
} else {
log.Warnf("handleDisconnectVPN: missing uuidOrName/name/uuid parameter")
models.RespondError(conn, req.ID, "missing 'uuidOrName', 'name', or 'uuid' parameter")
return
}
} }
if err := manager.DisconnectVPN(uuidOrName); err != nil { if err := manager.DisconnectVPN(uuidOrName); err != nil {
@@ -421,21 +483,27 @@ func handleDisconnectVPN(conn net.Conn, req models.Request, manager *Manager) {
return return
} }
models.Respond(conn, req.ID, models.SuccessResult{Success: true, Message: "VPN disconnected"}) models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "VPN disconnected"})
} }
func handleDisconnectAllVPN(conn net.Conn, req models.Request, manager *Manager) { func handleDisconnectAllVPN(conn net.Conn, req Request, manager *Manager) {
if err := manager.DisconnectAllVPN(); err != nil { if err := manager.DisconnectAllVPN(); err != nil {
log.Warnf("handleDisconnectAllVPN: failed: %v", err) log.Warnf("handleDisconnectAllVPN: failed: %v", err)
models.RespondError(conn, req.ID, fmt.Sprintf("failed to disconnect all VPNs: %v", err)) models.RespondError(conn, req.ID, fmt.Sprintf("failed to disconnect all VPNs: %v", err))
return return
} }
models.Respond(conn, req.ID, models.SuccessResult{Success: true, Message: "All VPNs disconnected"}) models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "All VPNs disconnected"})
} }
func handleClearVPNCredentials(conn net.Conn, req models.Request, manager *Manager) { func handleClearVPNCredentials(conn net.Conn, req Request, manager *Manager) {
uuidOrName, ok := params.StringAlt(req.Params, "uuid", "name", "uuidOrName") uuidOrName, ok := req.Params["uuid"].(string)
if !ok {
uuidOrName, ok = req.Params["name"].(string)
}
if !ok {
uuidOrName, ok = req.Params["uuidOrName"].(string)
}
if !ok { if !ok {
log.Warnf("handleClearVPNCredentials: missing uuidOrName/name/uuid parameter") log.Warnf("handleClearVPNCredentials: missing uuidOrName/name/uuid parameter")
models.RespondError(conn, req.ID, "missing uuidOrName/name/uuid parameter") models.RespondError(conn, req.ID, "missing uuidOrName/name/uuid parameter")
@@ -448,19 +516,19 @@ func handleClearVPNCredentials(conn net.Conn, req models.Request, manager *Manag
return return
} }
models.Respond(conn, req.ID, models.SuccessResult{Success: true, Message: "VPN credentials cleared"}) models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "VPN credentials cleared"})
} }
func handleSetWiFiAutoconnect(conn net.Conn, req models.Request, manager *Manager) { func handleSetWiFiAutoconnect(conn net.Conn, req Request, manager *Manager) {
ssid, err := params.String(req.Params, "ssid") ssid, ok := req.Params["ssid"].(string)
if err != nil { if !ok {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, "missing or invalid 'ssid' parameter")
return return
} }
autoconnect, err := params.Bool(req.Params, "autoconnect") autoconnect, ok := req.Params["autoconnect"].(bool)
if err != nil { if !ok {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, "missing or invalid 'autoconnect' parameter")
return return
} }
@@ -469,10 +537,10 @@ func handleSetWiFiAutoconnect(conn net.Conn, req models.Request, manager *Manage
return return
} }
models.Respond(conn, req.ID, models.SuccessResult{Success: true, Message: "autoconnect updated"}) models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "autoconnect updated"})
} }
func handleListVPNPlugins(conn net.Conn, req models.Request, manager *Manager) { func handleListVPNPlugins(conn net.Conn, req Request, manager *Manager) {
plugins, err := manager.ListVPNPlugins() plugins, err := manager.ListVPNPlugins()
if err != nil { if err != nil {
log.Warnf("handleListVPNPlugins: failed to list plugins: %v", err) log.Warnf("handleListVPNPlugins: failed to list plugins: %v", err)
@@ -483,14 +551,17 @@ func handleListVPNPlugins(conn net.Conn, req models.Request, manager *Manager) {
models.Respond(conn, req.ID, plugins) models.Respond(conn, req.ID, plugins)
} }
func handleImportVPN(conn net.Conn, req models.Request, manager *Manager) { func handleImportVPN(conn net.Conn, req Request, manager *Manager) {
filePath, ok := params.StringAlt(req.Params, "file", "path") filePath, ok := req.Params["file"].(string)
if !ok {
filePath, ok = req.Params["path"].(string)
}
if !ok { if !ok {
models.RespondError(conn, req.ID, "missing 'file' or 'path' parameter") models.RespondError(conn, req.ID, "missing 'file' or 'path' parameter")
return return
} }
name := params.StringOpt(req.Params, "name", "") name, _ := req.Params["name"].(string)
result, err := manager.ImportVPN(filePath, name) result, err := manager.ImportVPN(filePath, name)
if err != nil { if err != nil {
@@ -502,8 +573,14 @@ func handleImportVPN(conn net.Conn, req models.Request, manager *Manager) {
models.Respond(conn, req.ID, result) models.Respond(conn, req.ID, result)
} }
func handleGetVPNConfig(conn net.Conn, req models.Request, manager *Manager) { func handleGetVPNConfig(conn net.Conn, req Request, manager *Manager) {
uuidOrName, ok := params.StringAlt(req.Params, "uuid", "name", "uuidOrName") uuidOrName, ok := req.Params["uuid"].(string)
if !ok {
uuidOrName, ok = req.Params["name"].(string)
}
if !ok {
uuidOrName, ok = req.Params["uuidOrName"].(string)
}
if !ok { if !ok {
models.RespondError(conn, req.ID, "missing 'uuid', 'name', or 'uuidOrName' parameter") models.RespondError(conn, req.ID, "missing 'uuid', 'name', or 'uuidOrName' parameter")
return return
@@ -519,10 +596,10 @@ func handleGetVPNConfig(conn net.Conn, req models.Request, manager *Manager) {
models.Respond(conn, req.ID, config) models.Respond(conn, req.ID, config)
} }
func handleUpdateVPNConfig(conn net.Conn, req models.Request, manager *Manager) { func handleUpdateVPNConfig(conn net.Conn, req Request, manager *Manager) {
connUUID, err := params.String(req.Params, "uuid") connUUID, ok := req.Params["uuid"].(string)
if err != nil { if !ok {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, "missing 'uuid' parameter")
return return
} }
@@ -549,11 +626,17 @@ func handleUpdateVPNConfig(conn net.Conn, req models.Request, manager *Manager)
return return
} }
models.Respond(conn, req.ID, models.SuccessResult{Success: true, Message: "VPN config updated"}) models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "VPN config updated"})
} }
func handleDeleteVPN(conn net.Conn, req models.Request, manager *Manager) { func handleDeleteVPN(conn net.Conn, req Request, manager *Manager) {
uuidOrName, ok := params.StringAlt(req.Params, "uuid", "name", "uuidOrName") uuidOrName, ok := req.Params["uuid"].(string)
if !ok {
uuidOrName, ok = req.Params["name"].(string)
}
if !ok {
uuidOrName, ok = req.Params["uuidOrName"].(string)
}
if !ok { if !ok {
models.RespondError(conn, req.ID, "missing 'uuid', 'name', or 'uuidOrName' parameter") models.RespondError(conn, req.ID, "missing 'uuid', 'name', or 'uuidOrName' parameter")
return return
@@ -565,19 +648,23 @@ func handleDeleteVPN(conn net.Conn, req models.Request, manager *Manager) {
return return
} }
models.Respond(conn, req.ID, models.SuccessResult{Success: true, Message: "VPN deleted"}) models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "VPN deleted"})
} }
func handleSetVPNCredentials(conn net.Conn, req models.Request, manager *Manager) { func handleSetVPNCredentials(conn net.Conn, req Request, manager *Manager) {
connUUID, err := params.String(req.Params, "uuid") connUUID, ok := req.Params["uuid"].(string)
if err != nil { if !ok {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, "missing 'uuid' parameter")
return return
} }
username := params.StringOpt(req.Params, "username", "") username, _ := req.Params["username"].(string)
password := params.StringOpt(req.Params, "password", "") password, _ := req.Params["password"].(string)
save := params.BoolOpt(req.Params, "save", true)
save := true
if saveParam, ok := req.Params["save"].(bool); ok {
save = saveParam
}
if err := manager.SetVPNCredentials(connUUID, username, password, save); err != nil { if err := manager.SetVPNCredentials(connUUID, username, password, save); err != nil {
log.Warnf("handleSetVPNCredentials: failed to set credentials: %v", err) log.Warnf("handleSetVPNCredentials: failed to set credentials: %v", err)
@@ -585,5 +672,5 @@ func handleSetVPNCredentials(conn net.Conn, req models.Request, manager *Manager
return return
} }
models.Respond(conn, req.ID, models.SuccessResult{Success: true, Message: "VPN credentials set"}) models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "VPN credentials set"})
} }

View File

@@ -53,10 +53,10 @@ func TestRespondError_Network(t *testing.T) {
func TestRespond_Network(t *testing.T) { func TestRespond_Network(t *testing.T) {
conn := newMockNetConn() conn := newMockNetConn()
result := models.SuccessResult{Success: true, Message: "test"} result := SuccessResult{Success: true, Message: "test"}
models.Respond(conn, 123, result) models.Respond(conn, 123, result)
var resp models.Response[models.SuccessResult] var resp models.Response[SuccessResult]
err := json.NewDecoder(conn.writeBuf).Decode(&resp) err := json.NewDecoder(conn.writeBuf).Decode(&resp)
require.NoError(t, err) require.NoError(t, err)
@@ -77,7 +77,7 @@ func TestHandleGetState(t *testing.T) {
} }
conn := newMockNetConn() conn := newMockNetConn()
req := models.Request{ID: 123, Method: "network.getState"} req := Request{ID: 123, Method: "network.getState"}
handleGetState(conn, req, manager) handleGetState(conn, req, manager)
@@ -103,7 +103,7 @@ func TestHandleGetWiFiNetworks(t *testing.T) {
} }
conn := newMockNetConn() conn := newMockNetConn()
req := models.Request{ID: 123, Method: "network.wifi.networks"} req := Request{ID: 123, Method: "network.wifi.networks"}
handleGetWiFiNetworks(conn, req, manager) handleGetWiFiNetworks(conn, req, manager)
@@ -125,7 +125,7 @@ func TestHandleConnectWiFi(t *testing.T) {
} }
conn := newMockNetConn() conn := newMockNetConn()
req := models.Request{ req := Request{
ID: 123, ID: 123,
Method: "network.wifi.connect", Method: "network.wifi.connect",
Params: map[string]any{}, Params: map[string]any{},
@@ -149,7 +149,7 @@ func TestHandleSetPreference(t *testing.T) {
} }
conn := newMockNetConn() conn := newMockNetConn()
req := models.Request{ req := Request{
ID: 123, ID: 123,
Method: "network.preference.set", Method: "network.preference.set",
Params: map[string]any{}, Params: map[string]any{},
@@ -173,7 +173,7 @@ func TestHandleGetNetworkInfo(t *testing.T) {
} }
conn := newMockNetConn() conn := newMockNetConn()
req := models.Request{ req := Request{
ID: 123, ID: 123,
Method: "network.info", Method: "network.info",
Params: map[string]any{}, Params: map[string]any{},
@@ -199,7 +199,7 @@ func TestHandleRequest(t *testing.T) {
t.Run("unknown method", func(t *testing.T) { t.Run("unknown method", func(t *testing.T) {
conn := newMockNetConn() conn := newMockNetConn()
req := models.Request{ req := Request{
ID: 123, ID: 123,
Method: "network.unknown", Method: "network.unknown",
} }
@@ -216,7 +216,7 @@ func TestHandleRequest(t *testing.T) {
t.Run("valid method - getState", func(t *testing.T) { t.Run("valid method - getState", func(t *testing.T) {
conn := newMockNetConn() conn := newMockNetConn()
req := models.Request{ req := Request{
ID: 123, ID: 123,
Method: "network.getState", Method: "network.getState",
} }

View File

@@ -1,113 +0,0 @@
package params
import "fmt"
func Get[T any](params map[string]any, key string) (T, error) {
val, ok := params[key].(T)
if !ok {
var zero T
return zero, fmt.Errorf("missing or invalid '%s' parameter", key)
}
return val, nil
}
func GetOpt[T any](params map[string]any, key string, def T) T {
if val, ok := params[key].(T); ok {
return val
}
return def
}
func String(params map[string]any, key string) (string, error) {
return Get[string](params, key)
}
func StringNonEmpty(params map[string]any, key string) (string, error) {
val, err := Get[string](params, key)
if err != nil || val == "" {
return "", fmt.Errorf("missing or invalid '%s' parameter", key)
}
return val, nil
}
func StringOpt(params map[string]any, key string, def string) string {
return GetOpt(params, key, def)
}
func Int(params map[string]any, key string) (int, error) {
val, err := Get[float64](params, key)
if err != nil {
return 0, err
}
return int(val), nil
}
func IntOpt(params map[string]any, key string, def int) int {
if val, ok := params[key].(float64); ok {
return int(val)
}
return def
}
func Float(params map[string]any, key string) (float64, error) {
return Get[float64](params, key)
}
func FloatOpt(params map[string]any, key string, def float64) float64 {
return GetOpt(params, key, def)
}
func Bool(params map[string]any, key string) (bool, error) {
return Get[bool](params, key)
}
func BoolOpt(params map[string]any, key string, def bool) bool {
return GetOpt(params, key, def)
}
func StringMap(params map[string]any, key string) (map[string]string, error) {
rawMap, err := Get[map[string]any](params, key)
if err != nil {
return nil, err
}
result := make(map[string]string)
for k, v := range rawMap {
if str, ok := v.(string); ok {
result[k] = str
}
}
return result, nil
}
func StringMapOpt(params map[string]any, key string) map[string]string {
rawMap, ok := params[key].(map[string]any)
if !ok {
return nil
}
result := make(map[string]string)
for k, v := range rawMap {
if str, ok := v.(string); ok {
result[k] = str
}
}
return result
}
func Any(params map[string]any, key string) (any, bool) {
val, ok := params[key]
return val, ok
}
func AnyMap(params map[string]any, key string) (map[string]any, bool) {
val, ok := params[key].(map[string]any)
return val, ok
}
func StringAlt(params map[string]any, keys ...string) (string, bool) {
for _, key := range keys {
if val, ok := params[key].(string); ok {
return val, true
}
}
return "", false
}

View File

@@ -1,154 +0,0 @@
package params
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestGet(t *testing.T) {
p := map[string]any{"key": "value"}
val, err := Get[string](p, "key")
assert.NoError(t, err)
assert.Equal(t, "value", val)
_, err = Get[string](p, "missing")
assert.Error(t, err)
_, err = Get[int](p, "key")
assert.Error(t, err)
}
func TestGetOpt(t *testing.T) {
p := map[string]any{"key": "value"}
assert.Equal(t, "value", GetOpt(p, "key", "default"))
assert.Equal(t, "default", GetOpt(p, "missing", "default"))
}
func TestString(t *testing.T) {
p := map[string]any{"s": "hello", "n": 123}
val, err := String(p, "s")
assert.NoError(t, err)
assert.Equal(t, "hello", val)
_, err = String(p, "n")
assert.Error(t, err)
}
func TestStringNonEmpty(t *testing.T) {
p := map[string]any{"s": "hello", "empty": ""}
val, err := StringNonEmpty(p, "s")
assert.NoError(t, err)
assert.Equal(t, "hello", val)
_, err = StringNonEmpty(p, "empty")
assert.Error(t, err)
_, err = StringNonEmpty(p, "missing")
assert.Error(t, err)
}
func TestStringOpt(t *testing.T) {
p := map[string]any{"s": "hello"}
assert.Equal(t, "hello", StringOpt(p, "s", "default"))
assert.Equal(t, "default", StringOpt(p, "missing", "default"))
}
func TestInt(t *testing.T) {
p := map[string]any{"n": float64(42), "s": "str"}
val, err := Int(p, "n")
assert.NoError(t, err)
assert.Equal(t, 42, val)
_, err = Int(p, "s")
assert.Error(t, err)
}
func TestIntOpt(t *testing.T) {
p := map[string]any{"n": float64(42)}
assert.Equal(t, 42, IntOpt(p, "n", 0))
assert.Equal(t, 99, IntOpt(p, "missing", 99))
}
func TestFloat(t *testing.T) {
p := map[string]any{"f": 3.14, "s": "str"}
val, err := Float(p, "f")
assert.NoError(t, err)
assert.Equal(t, 3.14, val)
_, err = Float(p, "s")
assert.Error(t, err)
}
func TestFloatOpt(t *testing.T) {
p := map[string]any{"f": 3.14}
assert.Equal(t, 3.14, FloatOpt(p, "f", 0))
assert.Equal(t, 1.0, FloatOpt(p, "missing", 1.0))
}
func TestBool(t *testing.T) {
p := map[string]any{"b": true, "s": "str"}
val, err := Bool(p, "b")
assert.NoError(t, err)
assert.True(t, val)
_, err = Bool(p, "s")
assert.Error(t, err)
}
func TestBoolOpt(t *testing.T) {
p := map[string]any{"b": true}
assert.True(t, BoolOpt(p, "b", false))
assert.True(t, BoolOpt(p, "missing", true))
}
func TestStringMap(t *testing.T) {
p := map[string]any{
"m": map[string]any{"a": "1", "b": "2", "c": 3},
}
val, err := StringMap(p, "m")
assert.NoError(t, err)
assert.Equal(t, map[string]string{"a": "1", "b": "2"}, val)
_, err = StringMap(p, "missing")
assert.Error(t, err)
}
func TestStringMapOpt(t *testing.T) {
p := map[string]any{
"m": map[string]any{"a": "1"},
}
assert.Equal(t, map[string]string{"a": "1"}, StringMapOpt(p, "m"))
assert.Nil(t, StringMapOpt(p, "missing"))
}
func TestAny(t *testing.T) {
p := map[string]any{"k": 123}
val, ok := Any(p, "k")
assert.True(t, ok)
assert.Equal(t, 123, val)
_, ok = Any(p, "missing")
assert.False(t, ok)
}
func TestAnyMap(t *testing.T) {
inner := map[string]any{"nested": true}
p := map[string]any{"m": inner}
val, ok := AnyMap(p, "m")
assert.True(t, ok)
assert.Equal(t, inner, val)
_, ok = AnyMap(p, "missing")
assert.False(t, ok)
}
func TestStringAlt(t *testing.T) {
p := map[string]any{"b": "found"}
val, ok := StringAlt(p, "a", "b", "c")
assert.True(t, ok)
assert.Equal(t, "found", val)
_, ok = StringAlt(p, "x", "y")
assert.False(t, ok)
}

View File

@@ -15,50 +15,53 @@ func HandleUninstall(conn net.Conn, req models.Request) {
return return
} }
manager, err := plugins.NewManager()
if err != nil {
models.RespondError(conn, req.ID, fmt.Sprintf("failed to create manager: %v", err))
return
}
// First try to find in registry (by name or ID)
registry, err := plugins.NewRegistry() registry, err := plugins.NewRegistry()
if err != nil { if err != nil {
models.RespondError(conn, req.ID, fmt.Sprintf("failed to create registry: %v", err)) models.RespondError(conn, req.ID, fmt.Sprintf("failed to create registry: %v", err))
return return
} }
pluginList, _ := registry.List() pluginList, err := registry.List()
plugin := plugins.FindByIDOrName(name, pluginList) if err != nil {
models.RespondError(conn, req.ID, fmt.Sprintf("failed to list plugins: %v", err))
// If found in registry, use that
if plugin != nil {
installed, err := manager.IsInstalled(*plugin)
if err != nil {
models.RespondError(conn, req.ID, fmt.Sprintf("failed to check if plugin is installed: %v", err))
return
}
if !installed {
models.RespondError(conn, req.ID, fmt.Sprintf("plugin not installed: %s", name))
return
}
if err := manager.Uninstall(*plugin); err != nil {
models.RespondError(conn, req.ID, fmt.Sprintf("failed to uninstall plugin: %v", err))
return
}
models.Respond(conn, req.ID, SuccessResult{
Success: true,
Message: fmt.Sprintf("plugin uninstalled: %s", plugin.Name),
})
return return
} }
// Not in registry - try to find and uninstall from installed plugins directly var plugin *plugins.Plugin
if err := manager.UninstallByIDOrName(name); err != nil { for _, p := range pluginList {
if p.Name == name {
plugin = &p
break
}
}
if plugin == nil {
models.RespondError(conn, req.ID, fmt.Sprintf("plugin not found: %s", name)) models.RespondError(conn, req.ID, fmt.Sprintf("plugin not found: %s", name))
return return
} }
manager, err := plugins.NewManager()
if err != nil {
models.RespondError(conn, req.ID, fmt.Sprintf("failed to create manager: %v", err))
return
}
installed, err := manager.IsInstalled(*plugin)
if err != nil {
models.RespondError(conn, req.ID, fmt.Sprintf("failed to check if plugin is installed: %v", err))
return
}
if !installed {
models.RespondError(conn, req.ID, fmt.Sprintf("plugin not installed: %s", name))
return
}
if err := manager.Uninstall(*plugin); err != nil {
models.RespondError(conn, req.ID, fmt.Sprintf("failed to uninstall plugin: %v", err))
return
}
models.Respond(conn, req.ID, SuccessResult{ models.Respond(conn, req.ID, SuccessResult{
Success: true, Success: true,
Message: fmt.Sprintf("plugin uninstalled: %s", name), Message: fmt.Sprintf("plugin uninstalled: %s", name),

View File

@@ -15,48 +15,53 @@ func HandleUpdate(conn net.Conn, req models.Request) {
return return
} }
manager, err := plugins.NewManager()
if err != nil {
models.RespondError(conn, req.ID, fmt.Sprintf("failed to create manager: %v", err))
return
}
registry, err := plugins.NewRegistry() registry, err := plugins.NewRegistry()
if err != nil { if err != nil {
models.RespondError(conn, req.ID, fmt.Sprintf("failed to create registry: %v", err)) models.RespondError(conn, req.ID, fmt.Sprintf("failed to create registry: %v", err))
return return
} }
pluginList, _ := registry.List() pluginList, err := registry.List()
plugin := plugins.FindByIDOrName(name, pluginList) if err != nil {
models.RespondError(conn, req.ID, fmt.Sprintf("failed to list plugins: %v", err))
if plugin != nil {
installed, err := manager.IsInstalled(*plugin)
if err != nil {
models.RespondError(conn, req.ID, fmt.Sprintf("failed to check if plugin is installed: %v", err))
return
}
if !installed {
models.RespondError(conn, req.ID, fmt.Sprintf("plugin not installed: %s", name))
return
}
if err := manager.Update(*plugin); err != nil {
models.RespondError(conn, req.ID, fmt.Sprintf("failed to update plugin: %v", err))
return
}
models.Respond(conn, req.ID, SuccessResult{
Success: true,
Message: fmt.Sprintf("plugin updated: %s", plugin.Name),
})
return return
} }
// Not in registry - try to update from installed plugins directly var plugin *plugins.Plugin
if err := manager.UpdateByIDOrName(name); err != nil { for _, p := range pluginList {
if p.Name == name {
plugin = &p
break
}
}
if plugin == nil {
models.RespondError(conn, req.ID, fmt.Sprintf("plugin not found: %s", name)) models.RespondError(conn, req.ID, fmt.Sprintf("plugin not found: %s", name))
return return
} }
manager, err := plugins.NewManager()
if err != nil {
models.RespondError(conn, req.ID, fmt.Sprintf("failed to create manager: %v", err))
return
}
installed, err := manager.IsInstalled(*plugin)
if err != nil {
models.RespondError(conn, req.ID, fmt.Sprintf("failed to check if plugin is installed: %v", err))
return
}
if !installed {
models.RespondError(conn, req.ID, fmt.Sprintf("plugin not installed: %s", name))
return
}
if err := manager.Update(*plugin); err != nil {
models.RespondError(conn, req.ID, fmt.Sprintf("failed to update plugin: %v", err))
return
}
models.Respond(conn, req.ID, SuccessResult{ models.Respond(conn, req.ID, SuccessResult{
Success: true, Success: true,
Message: fmt.Sprintf("plugin updated: %s", name), Message: fmt.Sprintf("plugin updated: %s", name),

View File

@@ -27,7 +27,12 @@ func RouteRequest(conn net.Conn, req models.Request) {
models.RespondError(conn, req.ID, "network manager not initialized") models.RespondError(conn, req.ID, "network manager not initialized")
return return
} }
network.HandleRequest(conn, req, networkManager) netReq := network.Request{
ID: req.ID,
Method: req.Method,
Params: req.Params,
}
network.HandleRequest(conn, netReq, networkManager)
return return
} }
@@ -41,7 +46,12 @@ func RouteRequest(conn net.Conn, req models.Request) {
models.RespondError(conn, req.ID, "loginctl manager not initialized") models.RespondError(conn, req.ID, "loginctl manager not initialized")
return return
} }
loginctl.HandleRequest(conn, req, loginctlManager) loginReq := loginctl.Request{
ID: req.ID,
Method: req.Method,
Params: req.Params,
}
loginctl.HandleRequest(conn, loginReq, loginctlManager)
return return
} }
@@ -50,7 +60,12 @@ func RouteRequest(conn net.Conn, req models.Request) {
models.RespondError(conn, req.ID, "freedesktop manager not initialized") models.RespondError(conn, req.ID, "freedesktop manager not initialized")
return return
} }
freedesktop.HandleRequest(conn, req, freedesktopManager) freedeskReq := freedesktop.Request{
ID: req.ID,
Method: req.Method,
Params: req.Params,
}
freedesktop.HandleRequest(conn, freedeskReq, freedesktopManager)
return return
} }
@@ -59,7 +74,12 @@ func RouteRequest(conn net.Conn, req models.Request) {
models.RespondError(conn, req.ID, "wayland manager not initialized") models.RespondError(conn, req.ID, "wayland manager not initialized")
return return
} }
wayland.HandleRequest(conn, req, waylandManager) waylandReq := wayland.Request{
ID: req.ID,
Method: req.Method,
Params: req.Params,
}
wayland.HandleRequest(conn, waylandReq, waylandManager)
return return
} }
@@ -68,7 +88,12 @@ func RouteRequest(conn net.Conn, req models.Request) {
models.RespondError(conn, req.ID, "bluetooth manager not initialized") models.RespondError(conn, req.ID, "bluetooth manager not initialized")
return return
} }
bluez.HandleRequest(conn, req, bluezManager) bluezReq := bluez.Request{
ID: req.ID,
Method: req.Method,
Params: req.Params,
}
bluez.HandleRequest(conn, bluezReq, bluezManager)
return return
} }
@@ -77,7 +102,12 @@ func RouteRequest(conn net.Conn, req models.Request) {
models.RespondError(conn, req.ID, "apppicker manager not initialized") models.RespondError(conn, req.ID, "apppicker manager not initialized")
return return
} }
apppicker.HandleRequest(conn, req, appPickerManager) appPickerReq := apppicker.Request{
ID: req.ID,
Method: req.Method,
Params: req.Params,
}
apppicker.HandleRequest(conn, appPickerReq, appPickerManager)
return return
} }
@@ -86,7 +116,12 @@ func RouteRequest(conn net.Conn, req models.Request) {
models.RespondError(conn, req.ID, "CUPS manager not initialized") models.RespondError(conn, req.ID, "CUPS manager not initialized")
return return
} }
cups.HandleRequest(conn, req, cupsManager) cupsReq := cups.Request{
ID: req.ID,
Method: req.Method,
Params: req.Params,
}
cups.HandleRequest(conn, cupsReq, cupsManager)
return return
} }
@@ -95,7 +130,12 @@ func RouteRequest(conn net.Conn, req models.Request) {
models.RespondError(conn, req.ID, "dwl manager not initialized") models.RespondError(conn, req.ID, "dwl manager not initialized")
return return
} }
dwl.HandleRequest(conn, req, dwlManager) dwlReq := dwl.Request{
ID: req.ID,
Method: req.Method,
Params: req.Params,
}
dwl.HandleRequest(conn, dwlReq, dwlManager)
return return
} }
@@ -104,7 +144,12 @@ func RouteRequest(conn net.Conn, req models.Request) {
models.RespondError(conn, req.ID, "brightness manager not initialized") models.RespondError(conn, req.ID, "brightness manager not initialized")
return return
} }
brightness.HandleRequest(conn, req, brightnessManager) brightnessReq := brightness.Request{
ID: req.ID,
Method: req.Method,
Params: req.Params,
}
brightness.HandleRequest(conn, brightnessReq, brightnessManager)
return return
} }
@@ -125,7 +170,12 @@ func RouteRequest(conn net.Conn, req models.Request) {
return return
} }
} }
extworkspace.HandleRequest(conn, req, extWorkspaceManager) extWorkspaceReq := extworkspace.Request{
ID: req.ID,
Method: req.Method,
Params: req.Params,
}
extworkspace.HandleRequest(conn, extWorkspaceReq, extWorkspaceManager)
return return
} }
@@ -134,7 +184,12 @@ func RouteRequest(conn net.Conn, req models.Request) {
models.RespondError(conn, req.ID, "wlroutput manager not initialized") models.RespondError(conn, req.ID, "wlroutput manager not initialized")
return return
} }
wlroutput.HandleRequest(conn, req, wlrOutputManager) wlrOutputReq := wlroutput.Request{
ID: req.ID,
Method: req.Method,
Params: req.Params,
}
wlroutput.HandleRequest(conn, wlrOutputReq, wlrOutputManager)
return return
} }
@@ -143,7 +198,12 @@ func RouteRequest(conn net.Conn, req models.Request) {
models.RespondError(conn, req.ID, "evdev manager not initialized") models.RespondError(conn, req.ID, "evdev manager not initialized")
return return
} }
evdev.HandleRequest(conn, req, evdevManager) evdevReq := evdev.Request{
ID: req.ID,
Method: req.Method,
Params: req.Params,
}
evdev.HandleRequest(conn, evdevReq, evdevManager)
return return
} }

View File

@@ -2,6 +2,8 @@ package wayland
import ( import (
"math" "math"
"github.com/AvengeMedia/DankMaterialShell/core/internal/utils"
) )
type GammaRamp struct { type GammaRamp struct {
@@ -10,126 +12,6 @@ type GammaRamp struct {
Blue []uint16 Blue []uint16
} }
type rgb struct {
r, g, b float64
}
type xyz struct {
x, y, z float64
}
func illuminantD(temp int) (float64, float64, bool) {
var x float64
switch {
case temp >= 2500 && temp <= 7000:
t := float64(temp)
x = 0.244063 + 0.09911e3/t + 2.9678e6/(t*t) - 4.6070e9/(t*t*t)
case temp > 7000 && temp <= 25000:
t := float64(temp)
x = 0.237040 + 0.24748e3/t + 1.9018e6/(t*t) - 2.0064e9/(t*t*t)
default:
return 0, 0, false
}
y := -3*(x*x) + 2.870*x - 0.275
return x, y, true
}
func planckianLocus(temp int) (float64, float64, bool) {
var x, y float64
switch {
case temp >= 1667 && temp <= 4000:
t := float64(temp)
x = -0.2661239e9/(t*t*t) - 0.2343589e6/(t*t) + 0.8776956e3/t + 0.179910
if temp <= 2222 {
y = -1.1064814*(x*x*x) - 1.34811020*(x*x) + 2.18555832*x - 0.20219683
} else {
y = -0.9549476*(x*x*x) - 1.37418593*(x*x) + 2.09137015*x - 0.16748867
}
case temp > 4000 && temp < 25000:
t := float64(temp)
x = -3.0258469e9/(t*t*t) + 2.1070379e6/(t*t) + 0.2226347e3/t + 0.240390
y = 3.0817580*(x*x*x) - 5.87338670*(x*x) + 3.75112997*x - 0.37001483
default:
return 0, 0, false
}
return x, y, true
}
func srgbGamma(value, gamma float64) float64 {
if value <= 0.0031308 {
return 12.92 * value
}
return math.Pow(1.055*value, 1.0/gamma) - 0.055
}
func clamp01(v float64) float64 {
switch {
case v > 1.0:
return 1.0
case v < 0.0:
return 0.0
default:
return v
}
}
func xyzToSRGB(c xyz) rgb {
return rgb{
r: srgbGamma(clamp01(3.2404542*c.x-1.5371385*c.y-0.4985314*c.z), 2.2),
g: srgbGamma(clamp01(-0.9692660*c.x+1.8760108*c.y+0.0415560*c.z), 2.2),
b: srgbGamma(clamp01(0.0556434*c.x-0.2040259*c.y+1.0572252*c.z), 2.2),
}
}
func normalizeRGB(c *rgb) {
maxw := math.Max(c.r, math.Max(c.g, c.b))
if maxw > 0 {
c.r /= maxw
c.g /= maxw
c.b /= maxw
}
}
func calcWhitepoint(temp int) rgb {
if temp == 6500 {
return rgb{r: 1.0, g: 1.0, b: 1.0}
}
var wp xyz
switch {
case temp >= 25000:
x, y, _ := illuminantD(25000)
wp.x = x
wp.y = y
case temp >= 4000:
x, y, _ := illuminantD(temp)
wp.x = x
wp.y = y
case temp >= 2500:
x1, y1, _ := illuminantD(temp)
x2, y2, _ := planckianLocus(temp)
factor := float64(4000-temp) / 1500.0
sineFactor := (math.Cos(math.Pi*factor) + 1.0) / 2.0
wp.x = x1*sineFactor + x2*(1.0-sineFactor)
wp.y = y1*sineFactor + y2*(1.0-sineFactor)
default:
t := temp
if t < 1667 {
t = 1667
}
x, y, _ := planckianLocus(t)
wp.x = x
wp.y = y
}
wp.z = 1.0 - wp.x - wp.y
wpRGB := xyzToSRGB(wp)
normalizeRGB(&wpRGB)
return wpRGB
}
func GenerateGammaRamp(size uint32, temp int, gamma float64) GammaRamp { func GenerateGammaRamp(size uint32, temp int, gamma float64) GammaRamp {
ramp := GammaRamp{ ramp := GammaRamp{
Red: make([]uint16, size), Red: make([]uint16, size),
@@ -137,13 +19,16 @@ func GenerateGammaRamp(size uint32, temp int, gamma float64) GammaRamp {
Blue: make([]uint16, size), Blue: make([]uint16, size),
} }
wp := calcWhitepoint(temp)
for i := uint32(0); i < size; i++ { for i := uint32(0); i < size; i++ {
val := float64(i) / float64(size-1) val := float64(i) / float64(size-1)
ramp.Red[i] = uint16(clamp01(math.Pow(val*wp.r, 1.0/gamma)) * 65535.0)
ramp.Green[i] = uint16(clamp01(math.Pow(val*wp.g, 1.0/gamma)) * 65535.0) valGamma := math.Pow(val, 1.0/gamma)
ramp.Blue[i] = uint16(clamp01(math.Pow(val*wp.b, 1.0/gamma)) * 65535.0)
r, g, b := temperatureToRGB(temp)
ramp.Red[i] = uint16(utils.Clamp(valGamma*r*65535.0, 0, 65535))
ramp.Green[i] = uint16(utils.Clamp(valGamma*g*65535.0, 0, 65535))
ramp.Blue[i] = uint16(utils.Clamp(valGamma*b*65535.0, 0, 65535))
} }
return ramp return ramp
@@ -165,3 +50,39 @@ func GenerateIdentityRamp(size uint32) GammaRamp {
return ramp return ramp
} }
func temperatureToRGB(temp int) (float64, float64, float64) {
tempK := float64(temp) / 100.0
var r, g, b float64
if tempK <= 66 {
r = 1.0
} else {
r = tempK - 60
r = 329.698727446 * math.Pow(r, -0.1332047592)
r = utils.Clamp(r, 0, 255) / 255.0
}
if tempK <= 66 {
g = tempK
g = 99.4708025861*math.Log(g) - 161.1195681661
g = utils.Clamp(g, 0, 255) / 255.0
} else {
g = tempK - 60
g = 288.1221695283 * math.Pow(g, -0.0755148492)
g = utils.Clamp(g, 0, 255) / 255.0
}
if tempK >= 66 {
b = 1.0
} else if tempK <= 19 {
b = 0.0
} else {
b = tempK - 10
b = 138.5177312231*math.Log(b) - 305.0447927307
b = utils.Clamp(b, 0, 255) / 255.0
}
return r, g, b
}

View File

@@ -54,7 +54,7 @@ func TestGenerateGammaRamp(t *testing.T) {
} }
} }
func TestCalcWhitepoint(t *testing.T) { func TestTemperatureToRGB(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
temp int temp int
@@ -67,32 +67,32 @@ func TestCalcWhitepoint(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
wp := calcWhitepoint(tt.temp) r, g, b := temperatureToRGB(tt.temp)
if wp.r < 0 || wp.r > 1 { if r < 0 || r > 1 {
t.Errorf("red out of range: %f", wp.r) t.Errorf("red out of range: %f", r)
} }
if wp.g < 0 || wp.g > 1 { if g < 0 || g > 1 {
t.Errorf("green out of range: %f", wp.g) t.Errorf("green out of range: %f", g)
} }
if wp.b < 0 || wp.b > 1 { if b < 0 || b > 1 {
t.Errorf("blue out of range: %f", wp.b) t.Errorf("blue out of range: %f", b)
} }
}) })
} }
} }
func TestWhitepointProgression(t *testing.T) { func TestTemperatureProgression(t *testing.T) {
temps := []int{3000, 4000, 5000, 6000, 6500} temps := []int{3000, 4000, 5000, 6000, 6500}
var prevBlue float64 var prevBlue float64
for i, temp := range temps { for i, temp := range temps {
wp := calcWhitepoint(temp) _, _, b := temperatureToRGB(temp)
if i > 0 && wp.b < prevBlue { if i > 0 && b < prevBlue {
t.Errorf("blue should increase with temperature, %d->%d: %f->%f", t.Errorf("blue should increase with temperature, %d->%d: %f->%f",
temps[i-1], temp, prevBlue, wp.b) temps[i-1], temp, prevBlue, b)
} }
prevBlue = wp.b prevBlue = b
} }
} }

View File

@@ -7,10 +7,20 @@ import (
"time" "time"
"github.com/AvengeMedia/DankMaterialShell/core/internal/server/models" "github.com/AvengeMedia/DankMaterialShell/core/internal/server/models"
"github.com/AvengeMedia/DankMaterialShell/core/internal/server/params"
) )
func HandleRequest(conn net.Conn, req models.Request, manager *Manager) { type Request struct {
ID int `json:"id,omitempty"`
Method string `json:"method"`
Params map[string]any `json:"params,omitempty"`
}
type SuccessResult struct {
Success bool `json:"success"`
Message string `json:"message"`
}
func HandleRequest(conn net.Conn, req Request, manager *Manager) {
if manager == nil { if manager == nil {
models.RespondError(conn, req.ID, "wayland manager not initialized") models.RespondError(conn, req.ID, "wayland manager not initialized")
return return
@@ -38,27 +48,26 @@ func HandleRequest(conn net.Conn, req models.Request, manager *Manager) {
} }
} }
func handleGetState(conn net.Conn, req models.Request, manager *Manager) { func handleGetState(conn net.Conn, req Request, manager *Manager) {
models.Respond(conn, req.ID, manager.GetState()) state := manager.GetState()
models.Respond(conn, req.ID, state)
} }
func handleSetTemperature(conn net.Conn, req models.Request, manager *Manager) { func handleSetTemperature(conn net.Conn, req Request, manager *Manager) {
var lowTemp, highTemp int var lowTemp, highTemp int
if temp, ok := req.Params["temp"].(float64); ok { if temp, ok := req.Params["temp"].(float64); ok {
lowTemp = int(temp) lowTemp = int(temp)
highTemp = int(temp) highTemp = int(temp)
} else { } else {
low, err := params.Float(req.Params, "low") low, okLow := req.Params["low"].(float64)
if err != nil { high, okHigh := req.Params["high"].(float64)
models.RespondError(conn, req.ID, "missing temperature parameters (provide 'temp' or both 'low' and 'high')")
return if !okLow || !okHigh {
}
high, err := params.Float(req.Params, "high")
if err != nil {
models.RespondError(conn, req.ID, "missing temperature parameters (provide 'temp' or both 'low' and 'high')") models.RespondError(conn, req.ID, "missing temperature parameters (provide 'temp' or both 'low' and 'high')")
return return
} }
lowTemp = int(low) lowTemp = int(low)
highTemp = int(high) highTemp = int(high)
} }
@@ -68,19 +77,19 @@ func handleSetTemperature(conn net.Conn, req models.Request, manager *Manager) {
return return
} }
models.Respond(conn, req.ID, models.SuccessResult{Success: true, Message: "temperature set"}) models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "temperature set"})
} }
func handleSetLocation(conn net.Conn, req models.Request, manager *Manager) { func handleSetLocation(conn net.Conn, req Request, manager *Manager) {
lat, err := params.Float(req.Params, "latitude") lat, ok := req.Params["latitude"].(float64)
if err != nil { if !ok {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, "missing or invalid 'latitude' parameter")
return return
} }
lon, err := params.Float(req.Params, "longitude") lon, ok := req.Params["longitude"].(float64)
if err != nil { if !ok {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, "missing or invalid 'longitude' parameter")
return return
} }
@@ -89,30 +98,30 @@ func handleSetLocation(conn net.Conn, req models.Request, manager *Manager) {
return return
} }
models.Respond(conn, req.ID, models.SuccessResult{Success: true, Message: "location set"}) models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "location set"})
} }
func handleSetManualTimes(conn net.Conn, req models.Request, manager *Manager) { func handleSetManualTimes(conn net.Conn, req Request, manager *Manager) {
sunriseParam := req.Params["sunrise"] sunriseParam := req.Params["sunrise"]
sunsetParam := req.Params["sunset"] sunsetParam := req.Params["sunset"]
if sunriseParam == nil || sunsetParam == nil { if sunriseParam == nil || sunsetParam == nil {
manager.ClearManualTimes() manager.ClearManualTimes()
models.Respond(conn, req.ID, models.SuccessResult{Success: true, Message: "manual times cleared"}) models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "manual times cleared"})
return return
} }
sunriseStr, ok := sunriseParam.(string) sunriseStr, ok := sunriseParam.(string)
if !ok || sunriseStr == "" { if !ok || sunriseStr == "" {
manager.ClearManualTimes() manager.ClearManualTimes()
models.Respond(conn, req.ID, models.SuccessResult{Success: true, Message: "manual times cleared"}) models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "manual times cleared"})
return return
} }
sunsetStr, ok := sunsetParam.(string) sunsetStr, ok := sunsetParam.(string)
if !ok || sunsetStr == "" { if !ok || sunsetStr == "" {
manager.ClearManualTimes() manager.ClearManualTimes()
models.Respond(conn, req.ID, models.SuccessResult{Success: true, Message: "manual times cleared"}) models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "manual times cleared"})
return return
} }
@@ -133,24 +142,24 @@ func handleSetManualTimes(conn net.Conn, req models.Request, manager *Manager) {
return return
} }
models.Respond(conn, req.ID, models.SuccessResult{Success: true, Message: "manual times set"}) models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "manual times set"})
} }
func handleSetUseIPLocation(conn net.Conn, req models.Request, manager *Manager) { func handleSetUseIPLocation(conn net.Conn, req Request, manager *Manager) {
use, err := params.Bool(req.Params, "use") use, ok := req.Params["use"].(bool)
if err != nil { if !ok {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, "missing or invalid 'use' parameter")
return return
} }
manager.SetUseIPLocation(use) manager.SetUseIPLocation(use)
models.Respond(conn, req.ID, models.SuccessResult{Success: true, Message: "IP location preference set"}) models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "IP location preference set"})
} }
func handleSetGamma(conn net.Conn, req models.Request, manager *Manager) { func handleSetGamma(conn net.Conn, req Request, manager *Manager) {
gamma, err := params.Float(req.Params, "gamma") gamma, ok := req.Params["gamma"].(float64)
if err != nil { if !ok {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, "missing or invalid 'gamma' parameter")
return return
} }
@@ -159,21 +168,21 @@ func handleSetGamma(conn net.Conn, req models.Request, manager *Manager) {
return return
} }
models.Respond(conn, req.ID, models.SuccessResult{Success: true, Message: "gamma set"}) models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "gamma set"})
} }
func handleSetEnabled(conn net.Conn, req models.Request, manager *Manager) { func handleSetEnabled(conn net.Conn, req Request, manager *Manager) {
enabled, err := params.Bool(req.Params, "enabled") enabled, ok := req.Params["enabled"].(bool)
if err != nil { if !ok {
models.RespondError(conn, req.ID, err.Error()) models.RespondError(conn, req.ID, "missing or invalid 'enabled' parameter")
return return
} }
manager.SetEnabled(enabled) manager.SetEnabled(enabled)
models.Respond(conn, req.ID, models.SuccessResult{Success: true, Message: "enabled state set"}) models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "enabled state set"})
} }
func handleSubscribe(conn net.Conn, req models.Request, manager *Manager) { func handleSubscribe(conn net.Conn, req Request, manager *Manager) {
clientID := fmt.Sprintf("client-%p", conn) clientID := fmt.Sprintf("client-%p", conn)
stateChan := manager.Subscribe(clientID) stateChan := manager.Subscribe(clientID)
defer manager.Unsubscribe(clientID) defer manager.Unsubscribe(clientID)

File diff suppressed because it is too large Load Diff

View File

@@ -6,117 +6,81 @@ import (
) )
const ( const (
degToRad = math.Pi / 180.0 degToRad = math.Pi / 180.0
radToDeg = 180.0 / math.Pi radToDeg = 180.0 / math.Pi
solarNoon = 12.0
sunriseAngle = -0.833
) )
type SunCondition int
const (
SunNormal SunCondition = iota
SunMidnightSun
SunPolarNight
)
type SunTimes struct {
Dawn time.Time
Sunrise time.Time
Sunset time.Time
Night time.Time
}
func daysInYear(year int) int {
if (year%4 == 0 && year%100 != 0) || year%400 == 0 {
return 366
}
return 365
}
func dateOrbitAngle(t time.Time) float64 {
return 2 * math.Pi / float64(daysInYear(t.Year())) * float64(t.YearDay()-1)
}
func equationOfTime(orbitAngle float64) float64 {
return 4 * (0.000075 +
0.001868*math.Cos(orbitAngle) -
0.032077*math.Sin(orbitAngle) -
0.014615*math.Cos(2*orbitAngle) -
0.040849*math.Sin(2*orbitAngle))
}
func sunDeclination(orbitAngle float64) float64 {
return 0.006918 -
0.399912*math.Cos(orbitAngle) +
0.070257*math.Sin(orbitAngle) -
0.006758*math.Cos(2*orbitAngle) +
0.000907*math.Sin(2*orbitAngle) -
0.002697*math.Cos(3*orbitAngle) +
0.00148*math.Sin(3*orbitAngle)
}
func sunHourAngle(latRad, declination, targetSunRad float64) float64 {
return math.Acos(math.Cos(targetSunRad)/
math.Cos(latRad)*math.Cos(declination) -
math.Tan(latRad)*math.Tan(declination))
}
func hourAngleToSeconds(hourAngle, eqtime float64) float64 {
return radToDeg * (4.0*math.Pi - 4*hourAngle - eqtime) * 60
}
func sunCondition(latRad, declination float64) SunCondition {
signLat := latRad >= 0
signDecl := declination >= 0
if signLat == signDecl {
return SunMidnightSun
}
return SunPolarNight
}
func CalculateSunTimesWithTwilight(lat, lon float64, date time.Time, elevTwilight, elevDaylight float64) (SunTimes, SunCondition) {
latRad := lat * degToRad
elevTwilightRad := (90.833 - elevTwilight) * degToRad
elevDaylightRad := (90.833 - elevDaylight) * degToRad
utc := date.UTC()
orbitAngle := dateOrbitAngle(utc)
decl := sunDeclination(orbitAngle)
eqtime := equationOfTime(orbitAngle)
haTwilight := sunHourAngle(latRad, decl, elevTwilightRad)
haDaylight := sunHourAngle(latRad, decl, elevDaylightRad)
if math.IsNaN(haTwilight) || math.IsNaN(haDaylight) {
cond := sunCondition(latRad, decl)
return SunTimes{}, cond
}
dayStart := time.Date(utc.Year(), utc.Month(), utc.Day(), 0, 0, 0, 0, time.UTC)
lonOffset := time.Duration(-lon*4) * time.Minute
dawnSecs := hourAngleToSeconds(math.Abs(haTwilight), eqtime)
sunriseSecs := hourAngleToSeconds(math.Abs(haDaylight), eqtime)
sunsetSecs := hourAngleToSeconds(-math.Abs(haDaylight), eqtime)
nightSecs := hourAngleToSeconds(-math.Abs(haTwilight), eqtime)
return SunTimes{
Dawn: dayStart.Add(time.Duration(dawnSecs)*time.Second + lonOffset).In(date.Location()),
Sunrise: dayStart.Add(time.Duration(sunriseSecs)*time.Second + lonOffset).In(date.Location()),
Sunset: dayStart.Add(time.Duration(sunsetSecs)*time.Second + lonOffset).In(date.Location()),
Night: dayStart.Add(time.Duration(nightSecs)*time.Second + lonOffset).In(date.Location()),
}, SunNormal
}
func CalculateSunTimes(lat, lon float64, date time.Time) SunTimes { func CalculateSunTimes(lat, lon float64, date time.Time) SunTimes {
times, cond := CalculateSunTimesWithTwilight(lat, lon, date, -6.0, 3.0) utcDate := date.UTC()
switch cond { year, month, day := utcDate.Date()
case SunMidnightSun: loc := date.Location()
dayStart := time.Date(date.Year(), date.Month(), date.Day(), 0, 0, 0, 0, date.Location())
dayEnd := dayStart.Add(24*time.Hour - time.Second) dayOfYear := utcDate.YearDay()
return SunTimes{Dawn: dayStart, Sunrise: dayStart, Sunset: dayEnd, Night: dayEnd}
case SunPolarNight: gamma := 2 * math.Pi / 365 * float64(dayOfYear-1)
dayStart := time.Date(date.Year(), date.Month(), date.Day(), 0, 0, 0, 0, date.Location())
return SunTimes{Dawn: dayStart, Sunrise: dayStart, Sunset: dayStart, Night: dayStart} eqTime := 229.18 * (0.000075 +
0.001868*math.Cos(gamma) -
0.032077*math.Sin(gamma) -
0.014615*math.Cos(2*gamma) -
0.040849*math.Sin(2*gamma))
decl := 0.006918 -
0.399912*math.Cos(gamma) +
0.070257*math.Sin(gamma) -
0.006758*math.Cos(2*gamma) +
0.000907*math.Sin(2*gamma) -
0.002697*math.Cos(3*gamma) +
0.00148*math.Sin(3*gamma)
latRad := lat * degToRad
cosHourAngle := (math.Sin(sunriseAngle*degToRad) -
math.Sin(latRad)*math.Sin(decl)) /
(math.Cos(latRad) * math.Cos(decl))
if cosHourAngle > 1 {
return SunTimes{
Sunrise: time.Date(year, month, day, 0, 0, 0, 0, time.UTC).In(loc),
Sunset: time.Date(year, month, day, 0, 0, 0, 0, time.UTC).In(loc),
}
}
if cosHourAngle < -1 {
return SunTimes{
Sunrise: time.Date(year, month, day, 0, 0, 0, 0, time.UTC).In(loc),
Sunset: time.Date(year, month, day, 23, 59, 59, 0, time.UTC).In(loc),
}
}
hourAngle := math.Acos(cosHourAngle) * radToDeg
sunriseTime := solarNoon - hourAngle/15.0 - lon/15.0 - eqTime/60.0
sunsetTime := solarNoon + hourAngle/15.0 - lon/15.0 - eqTime/60.0
sunrise := timeOfDayToTime(sunriseTime, year, month, day, time.UTC).In(loc)
sunset := timeOfDayToTime(sunsetTime, year, month, day, time.UTC).In(loc)
return SunTimes{
Sunrise: sunrise,
Sunset: sunset,
} }
return times }
func timeOfDayToTime(hours float64, year int, month time.Month, day int, loc *time.Location) time.Time {
h := int(hours)
m := int((hours - float64(h)) * 60)
s := int(((hours-float64(h))*60 - float64(m)) * 60)
if h < 0 {
h += 24
day--
}
if h >= 24 {
h -= 24
day++
}
return time.Date(year, month, day, h, m, s, 0, loc)
} }

View File

@@ -340,47 +340,38 @@ func TestCalculateNextTransition(t *testing.T) {
} }
} }
func TestSunTimesWithTwilight(t *testing.T) { func TestTimeOfDayToTime(t *testing.T) {
lat := 40.7128
lon := -74.0060
date := time.Date(2024, 6, 21, 12, 0, 0, 0, time.Local)
times, cond := CalculateSunTimesWithTwilight(lat, lon, date, -6.0, 3.0)
if cond != SunNormal {
t.Errorf("expected SunNormal, got %v", cond)
}
if !times.Dawn.Before(times.Sunrise) {
t.Error("dawn should be before sunrise")
}
if !times.Sunrise.Before(times.Sunset) {
t.Error("sunrise should be before sunset")
}
if !times.Sunset.Before(times.Night) {
t.Error("sunset should be before night")
}
}
func TestSunConditions(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
lat float64 hours float64
date time.Time expected time.Time
expected SunCondition
}{ }{
{ {
name: "normal_conditions", name: "noon",
lat: 40.0, hours: 12.0,
date: time.Date(2024, 6, 21, 12, 0, 0, 0, time.UTC), expected: time.Date(2024, 6, 21, 12, 0, 0, 0, time.Local),
expected: SunNormal, },
{
name: "half_past",
hours: 12.5,
expected: time.Date(2024, 6, 21, 12, 30, 0, 0, time.Local),
},
{
name: "early_morning",
hours: 6.25,
expected: time.Date(2024, 6, 21, 6, 15, 0, 0, time.Local),
}, },
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
_, cond := CalculateSunTimesWithTwilight(tt.lat, 0, tt.date, -6.0, 3.0) result := timeOfDayToTime(tt.hours, 2024, 6, 21, time.Local)
if cond != tt.expected {
t.Errorf("expected condition %v, got %v", tt.expected, cond) if result.Hour() != tt.expected.Hour() {
t.Errorf("hour = %d, want %d", result.Hour(), tt.expected.Hour())
}
if result.Minute() != tt.expected.Minute() {
t.Errorf("minute = %d, want %d", result.Minute(), tt.expected.Minute())
} }
}) })
} }

View File

@@ -11,28 +11,18 @@ import (
"github.com/godbus/dbus/v5" "github.com/godbus/dbus/v5"
) )
type GammaState int
const (
StateNormal GammaState = iota
StateTransition
StateStatic
)
type Config struct { type Config struct {
Outputs []string Outputs []string
LowTemp int LowTemp int
HighTemp int HighTemp int
Latitude *float64 Latitude *float64
Longitude *float64 Longitude *float64
UseIPLocation bool UseIPLocation bool
ManualSunrise *time.Time ManualSunrise *time.Time
ManualSunset *time.Time ManualSunset *time.Time
ManualDuration *time.Duration ManualDuration *time.Duration
Gamma float64 Gamma float64
Enabled bool Enabled bool
ElevationTwilight float64
ElevationDaylight float64
} }
type State struct { type State struct {
@@ -41,24 +31,13 @@ type State struct {
NextTransition time.Time `json:"nextTransition"` NextTransition time.Time `json:"nextTransition"`
SunriseTime time.Time `json:"sunriseTime"` SunriseTime time.Time `json:"sunriseTime"`
SunsetTime time.Time `json:"sunsetTime"` SunsetTime time.Time `json:"sunsetTime"`
DawnTime time.Time `json:"dawnTime"`
NightTime time.Time `json:"nightTime"`
IsDay bool `json:"isDay"` IsDay bool `json:"isDay"`
SunPosition float64 `json:"sunPosition"`
} }
type cmd struct { type cmd struct {
fn func() fn func()
} }
type sunSchedule struct {
times SunTimes
condition SunCondition
dawnStepTime time.Duration
nightStepTime time.Duration
calcDay time.Time
}
type Manager struct { type Manager struct {
config Config config Config
configMutex sync.RWMutex configMutex sync.RWMutex
@@ -81,9 +60,10 @@ type Manager struct {
updateTrigger chan struct{} updateTrigger chan struct{}
wg sync.WaitGroup wg sync.WaitGroup
schedule sunSchedule currentTemp int
scheduleMutex sync.RWMutex targetTemp int
gammaState GammaState transitionMutex sync.RWMutex
transitionChan chan int
cachedIPLat *float64 cachedIPLat *float64
cachedIPLon *float64 cachedIPLon *float64
@@ -100,6 +80,7 @@ type Manager struct {
type outputState struct { type outputState struct {
id uint32 id uint32
name string
registryName uint32 registryName uint32
output *wlclient.Output output *wlclient.Output
gammaControl any gammaControl any
@@ -110,15 +91,18 @@ type outputState struct {
lastFailTime time.Time lastFailTime time.Time
} }
type SunTimes struct {
Sunrise time.Time
Sunset time.Time
}
func DefaultConfig() Config { func DefaultConfig() Config {
return Config{ return Config{
Outputs: []string{}, Outputs: []string{},
LowTemp: 4000, LowTemp: 4000,
HighTemp: 6500, HighTemp: 6500,
Gamma: 1.0, Gamma: 1.0,
Enabled: false, Enabled: false,
ElevationTwilight: -6.0,
ElevationDaylight: 3.0,
} }
} }
@@ -156,7 +140,8 @@ func (m *Manager) GetState() State {
if m.state == nil { if m.state == nil {
return State{} return State{}
} }
return *m.state stateCopy := *m.state
return stateCopy
} }
func (m *Manager) Subscribe(id string) chan State { func (m *Manager) Subscribe(id string) chan State {
@@ -200,8 +185,5 @@ func stateChanged(old, new *State) bool {
if old.Config.Enabled != new.Config.Enabled { if old.Config.Enabled != new.Config.Enabled {
return true return true
} }
if old.SunPosition != new.SunPosition {
return true
}
return false return false
} }

View File

@@ -11,6 +11,17 @@ import (
"github.com/AvengeMedia/DankMaterialShell/core/internal/server/models" "github.com/AvengeMedia/DankMaterialShell/core/internal/server/models"
) )
type Request struct {
ID int `json:"id,omitempty"`
Method string `json:"method"`
Params map[string]any `json:"params,omitempty"`
}
type SuccessResult struct {
Success bool `json:"success"`
Message string `json:"message"`
}
type HeadConfig struct { type HeadConfig struct {
Name string `json:"name"` Name string `json:"name"`
Enabled bool `json:"enabled"` Enabled bool `json:"enabled"`
@@ -31,7 +42,7 @@ type ConfigurationRequest struct {
Test bool `json:"test"` Test bool `json:"test"`
} }
func HandleRequest(conn net.Conn, req models.Request, manager *Manager) { func HandleRequest(conn net.Conn, req Request, manager *Manager) {
if manager == nil { if manager == nil {
models.RespondError(conn, req.ID, "wlroutput manager not initialized") models.RespondError(conn, req.ID, "wlroutput manager not initialized")
return return
@@ -51,11 +62,12 @@ func HandleRequest(conn net.Conn, req models.Request, manager *Manager) {
} }
} }
func handleGetState(conn net.Conn, req models.Request, manager *Manager) { func handleGetState(conn net.Conn, req Request, manager *Manager) {
models.Respond(conn, req.ID, manager.GetState()) state := manager.GetState()
models.Respond(conn, req.ID, state)
} }
func handleApplyConfiguration(conn net.Conn, req models.Request, manager *Manager, test bool) { func handleApplyConfiguration(conn net.Conn, req Request, manager *Manager, test bool) {
headsParam, ok := req.Params["heads"] headsParam, ok := req.Params["heads"]
if !ok { if !ok {
models.RespondError(conn, req.ID, "missing 'heads' parameter") models.RespondError(conn, req.ID, "missing 'heads' parameter")
@@ -83,10 +95,10 @@ func handleApplyConfiguration(conn net.Conn, req models.Request, manager *Manage
if test { if test {
msg = "configuration test succeeded" msg = "configuration test succeeded"
} }
models.Respond(conn, req.ID, models.SuccessResult{Success: true, Message: msg}) models.Respond(conn, req.ID, SuccessResult{Success: true, Message: msg})
} }
func handleSubscribe(conn net.Conn, req models.Request, manager *Manager) { func handleSubscribe(conn net.Conn, req Request, manager *Manager) {
clientID := fmt.Sprintf("client-%p", conn) clientID := fmt.Sprintf("client-%p", conn)
stateChan := manager.Subscribe(clientID) stateChan := manager.Subscribe(clientID)
defer manager.Unsubscribe(clientID) defer manager.Unsubscribe(clientID)

View File

@@ -4,8 +4,6 @@ import (
"fmt" "fmt"
"strings" "strings"
"github.com/AvengeMedia/DankMaterialShell/core/internal/deps"
"github.com/AvengeMedia/DankMaterialShell/core/internal/distros"
tea "github.com/charmbracelet/bubbletea" tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss" "github.com/charmbracelet/lipgloss"
) )
@@ -120,59 +118,6 @@ func (m Model) viewInstallingPackages() string {
return b.String() return b.String()
} }
func dmsPackageName(distroID string, dependencies []deps.Dependency) string {
config, ok := distros.Registry[distroID]
if !ok {
return "dms"
}
var isGit bool
for _, dep := range dependencies {
if dep.Name == "dms (DankMaterialShell)" {
isGit = dep.Variant == deps.VariantGit
break
}
}
switch config.Family {
case distros.FamilyArch:
if isGit {
return "dms-shell-git"
}
return "dms-shell-bin"
case distros.FamilyFedora, distros.FamilyUbuntu, distros.FamilyDebian, distros.FamilySUSE:
if isGit {
return "dms-git"
}
return "dms"
default:
return "dms"
}
}
func uninstallCommand(distroID string, dependencies []deps.Dependency) string {
config, ok := distros.Registry[distroID]
if !ok {
return ""
}
if config.Family == distros.FamilyGentoo {
return "rm -rf ~/.config/quickshell/dms && sudo rm /usr/local/bin/dms"
}
pkg := dmsPackageName(distroID, dependencies)
switch config.Family {
case distros.FamilyArch:
return "sudo pacman -Rs " + pkg
case distros.FamilyFedora:
return "sudo dnf remove " + pkg
case distros.FamilyUbuntu, distros.FamilyDebian:
return "sudo apt remove " + pkg
case distros.FamilySUSE:
return "sudo zypper remove " + pkg
default:
return ""
}
}
func (m Model) viewInstallComplete() string { func (m Model) viewInstallComplete() string {
var b strings.Builder var b strings.Builder
@@ -187,6 +132,7 @@ func (m Model) viewInstallComplete() string {
b.WriteString(success) b.WriteString(success)
b.WriteString("\n\n") b.WriteString("\n\n")
// Show what was accomplished
accomplishments := []string{ accomplishments := []string{
"• Window manager and dependencies installed", "• Window manager and dependencies installed",
"• Terminal and development tools configured", "• Terminal and development tools configured",
@@ -200,26 +146,8 @@ func (m Model) viewInstallComplete() string {
} }
b.WriteString("\n") b.WriteString("\n")
info := m.styles.Normal.Render("Your system is ready! Log out and log back in to start using\nyour new desktop environment.\nIf you do not have a greeter, login with \"niri-session\" or \"Hyprland\"") info := m.styles.Normal.Render("Your system is ready! Log out and log back in to start using\nyour new desktop environment.\nIf you do not have a greeter, login with \"niri-session\" or \"Hyprland\" \n\nPress Enter to exit.")
b.WriteString(info) b.WriteString(info)
b.WriteString("\n\n")
theme := TerminalTheme()
cmdStyle := lipgloss.NewStyle().Foreground(lipgloss.Color(theme.Accent))
labelStyle := lipgloss.NewStyle().Foreground(lipgloss.Color(theme.Subtle))
b.WriteString(labelStyle.Render("Troubleshooting:") + "\n")
b.WriteString(labelStyle.Render(" Disable autostart: ") + cmdStyle.Render("systemctl --user disable dms") + "\n")
b.WriteString(labelStyle.Render(" View logs: ") + cmdStyle.Render("journalctl --user -u dms") + "\n")
if m.osInfo != nil {
if cmd := uninstallCommand(m.osInfo.Distribution.ID, m.dependencies); cmd != "" {
b.WriteString(labelStyle.Render(" Uninstall: ") + cmdStyle.Render(cmd) + "\n")
}
}
b.WriteString("\n")
b.WriteString(m.styles.Normal.Render("Press Enter to exit."))
if m.logFilePath != "" { if m.logFilePath != "" {
b.WriteString("\n\n") b.WriteString("\n\n")

View File

@@ -40,7 +40,7 @@ func (m Model) viewWelcome() string {
subtitle := lipgloss.NewStyle(). subtitle := lipgloss.NewStyle().
Foreground(lipgloss.Color(theme.Subtle)). Foreground(lipgloss.Color(theme.Subtle)).
Italic(true). Italic(true).
Render("Quickstart for a Dank Desktop") Render("Quickstart for a Dank Desktop")
b.WriteString(decorator) b.WriteString(decorator)
b.WriteString("\n") b.WriteString("\n")

View File

@@ -1,8 +0,0 @@
package utils
import "os/exec"
func CommandExists(cmd string) bool {
_, err := exec.LookPath(cmd)
return err == nil
}

View File

@@ -1,52 +0,0 @@
package utils
import (
"os"
"path/filepath"
"strings"
)
func ExpandPath(path string) (string, error) {
expanded := os.ExpandEnv(path)
expanded = filepath.Clean(expanded)
if strings.HasPrefix(expanded, "~") {
home, err := os.UserHomeDir()
if err != nil {
return "", err
}
expanded = filepath.Join(home, expanded[1:])
}
return expanded, nil
}
func XDGConfigHome() string {
if configHome := os.Getenv("XDG_CONFIG_HOME"); configHome != "" {
return configHome
}
if home, err := os.UserHomeDir(); err == nil {
return filepath.Join(home, ".config")
}
return filepath.Join(os.TempDir(), ".config")
}
func XDGCacheHome() string {
if cacheHome := os.Getenv("XDG_CACHE_HOME"); cacheHome != "" {
return cacheHome
}
if home, err := os.UserHomeDir(); err == nil {
return filepath.Join(home, ".cache")
}
return filepath.Join(os.TempDir(), ".cache")
}
func XDGDataHome() string {
if dataHome := os.Getenv("XDG_DATA_HOME"); dataHome != "" {
return dataHome
}
if home, err := os.UserHomeDir(); err == nil {
return filepath.Join(home, ".local", "share")
}
return filepath.Join(os.TempDir(), ".local", "share")
}

View File

@@ -1,106 +0,0 @@
package utils
import (
"os"
"path/filepath"
"testing"
)
func TestExpandPathTilde(t *testing.T) {
home, err := os.UserHomeDir()
if err != nil {
t.Skip("no home directory")
}
result, err := ExpandPath("~/test")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
expected := filepath.Join(home, "test")
if result != expected {
t.Errorf("expected %s, got %s", expected, result)
}
}
func TestExpandPathEnvVar(t *testing.T) {
t.Setenv("TEST_PATH_VAR", "/custom/path")
result, err := ExpandPath("$TEST_PATH_VAR/subdir")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result != "/custom/path/subdir" {
t.Errorf("expected /custom/path/subdir, got %s", result)
}
}
func TestExpandPathAbsolute(t *testing.T) {
result, err := ExpandPath("/absolute/path")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result != "/absolute/path" {
t.Errorf("expected /absolute/path, got %s", result)
}
}
func TestXDGConfigHomeDefault(t *testing.T) {
t.Setenv("XDG_CONFIG_HOME", "")
home, err := os.UserHomeDir()
if err != nil {
t.Skip("no home directory")
}
result := XDGConfigHome()
expected := filepath.Join(home, ".config")
if result != expected {
t.Errorf("expected %s, got %s", expected, result)
}
}
func TestXDGConfigHomeCustom(t *testing.T) {
t.Setenv("XDG_CONFIG_HOME", "/custom/config")
result := XDGConfigHome()
if result != "/custom/config" {
t.Errorf("expected /custom/config, got %s", result)
}
}
func TestXDGCacheHomeDefault(t *testing.T) {
t.Setenv("XDG_CACHE_HOME", "")
home, err := os.UserHomeDir()
if err != nil {
t.Skip("no home directory")
}
result := XDGCacheHome()
expected := filepath.Join(home, ".cache")
if result != expected {
t.Errorf("expected %s, got %s", expected, result)
}
}
func TestXDGCacheHomeCustom(t *testing.T) {
t.Setenv("XDG_CACHE_HOME", "/custom/cache")
result := XDGCacheHome()
if result != "/custom/cache" {
t.Errorf("expected /custom/cache, got %s", result)
}
}
func TestXDGDataHomeDefault(t *testing.T) {
t.Setenv("XDG_DATA_HOME", "")
home, err := os.UserHomeDir()
if err != nil {
t.Skip("no home directory")
}
result := XDGDataHome()
expected := filepath.Join(home, ".local", "share")
if result != expected {
t.Errorf("expected %s, got %s", expected, result)
}
}
func TestXDGDataHomeCustom(t *testing.T) {
t.Setenv("XDG_DATA_HOME", "/custom/data")
result := XDGDataHome()
if result != "/custom/data" {
t.Errorf("expected /custom/data, got %s", result)
}
}

View File

@@ -1,56 +0,0 @@
package utils
func Filter[T any](items []T, predicate func(T) bool) []T {
var result []T
for _, item := range items {
if predicate(item) {
result = append(result, item)
}
}
return result
}
func Find[T any](items []T, predicate func(T) bool) (T, bool) {
for _, item := range items {
if predicate(item) {
return item, true
}
}
var zero T
return zero, false
}
func Map[T, U any](items []T, transform func(T) U) []U {
result := make([]U, len(items))
for i, item := range items {
result[i] = transform(item)
}
return result
}
func Contains[T comparable](items []T, target T) bool {
for _, item := range items {
if item == target {
return true
}
}
return false
}
func Any[T any](items []T, predicate func(T) bool) bool {
for _, item := range items {
if predicate(item) {
return true
}
}
return false
}
func All[T any](items []T, predicate func(T) bool) bool {
for _, item := range items {
if !predicate(item) {
return false
}
}
return true
}

View File

@@ -1,72 +0,0 @@
package utils
import (
"testing"
)
func TestFilter(t *testing.T) {
nums := []int{1, 2, 3, 4, 5}
evens := Filter(nums, func(n int) bool { return n%2 == 0 })
if len(evens) != 2 || evens[0] != 2 || evens[1] != 4 {
t.Errorf("expected [2, 4], got %v", evens)
}
}
func TestFilterEmpty(t *testing.T) {
result := Filter([]int{1, 2, 3}, func(n int) bool { return n > 10 })
if len(result) != 0 {
t.Errorf("expected empty slice, got %v", result)
}
}
func TestFind(t *testing.T) {
nums := []int{1, 2, 3, 4, 5}
val, found := Find(nums, func(n int) bool { return n == 3 })
if !found || val != 3 {
t.Errorf("expected 3, got %v (found=%v)", val, found)
}
}
func TestFindNotFound(t *testing.T) {
nums := []int{1, 2, 3}
val, found := Find(nums, func(n int) bool { return n == 99 })
if found || val != 0 {
t.Errorf("expected zero value not found, got %v (found=%v)", val, found)
}
}
func TestMap(t *testing.T) {
nums := []int{1, 2, 3}
doubled := Map(nums, func(n int) int { return n * 2 })
if len(doubled) != 3 || doubled[0] != 2 || doubled[1] != 4 || doubled[2] != 6 {
t.Errorf("expected [2, 4, 6], got %v", doubled)
}
}
func TestMapTypeConversion(t *testing.T) {
nums := []int{1, 2, 3}
strs := Map(nums, func(n int) string { return string(rune('a' + n - 1)) })
if strs[0] != "a" || strs[1] != "b" || strs[2] != "c" {
t.Errorf("expected [a, b, c], got %v", strs)
}
}
func TestContains(t *testing.T) {
nums := []int{1, 2, 3}
if !Contains(nums, 2) {
t.Error("expected to contain 2")
}
if Contains(nums, 99) {
t.Error("expected not to contain 99")
}
}
func TestAny(t *testing.T) {
nums := []int{1, 2, 3, 4, 5}
if !Any(nums, func(n int) bool { return n > 4 }) {
t.Error("expected any > 4")
}
if Any(nums, func(n int) bool { return n > 10 }) {
t.Error("expected none > 10")
}
}

Some files were not shown because too many files have changed in this diff Show More