1
0
mirror of https://github.com/AvengeMedia/DankMaterialShell.git synced 2025-12-06 05:25:41 -05:00

core: refactor to use a generic-compatible syncmap

This commit is contained in:
bbedward
2025-11-15 19:44:47 -05:00
parent 4cb652abd9
commit 67557555f2
36 changed files with 936 additions and 543 deletions

View File

@@ -32,14 +32,11 @@ func NewManager() (*Manager, error) {
}, },
stateMutex: sync.RWMutex{}, stateMutex: sync.RWMutex{},
stopChan: make(chan struct{}), stopChan: make(chan struct{}),
dbusConn: conn, dbusConn: conn,
signals: make(chan *dbus.Signal, 256), signals: make(chan *dbus.Signal, 256),
pairingSubscribers: make(map[string]chan PairingPrompt), dirty: make(chan struct{}, 1),
pairingSubMutex: sync.RWMutex{}, eventQueue: make(chan func(), 32),
dirty: make(chan struct{}, 1),
pendingPairings: make(map[string]bool),
eventQueue: make(chan func(), 32),
} }
broker := NewSubscriptionBroker(m.broadcastPairingPrompt) broker := NewSubscriptionBroker(m.broadcastPairingPrompt)
@@ -359,12 +356,7 @@ func (m *Manager) handleDevicePropertiesChanged(path dbus.ObjectPath, changed ma
if hasPaired { if hasPaired {
if paired, ok := pairedVar.Value().(bool); ok && paired { if paired, ok := pairedVar.Value().(bool); ok && paired {
devicePath := string(path) devicePath := string(path)
m.pendingPairingsMux.Lock() _, wasPending := m.pendingPairings.LoadAndDelete(devicePath)
wasPending := m.pendingPairings[devicePath]
if wasPending {
delete(m.pendingPairings, devicePath)
}
m.pendingPairingsMux.Unlock()
if wasPending { if wasPending {
select { select {
@@ -436,8 +428,7 @@ func (m *Manager) notifier() {
continue continue
} }
m.subscribers.Range(func(key, value interface{}) bool { m.subscribers.Range(func(key string, ch chan BluetoothState) bool {
ch := value.(chan BluetoothState)
select { select {
case ch <- currentState: case ch <- currentState:
default: default:
@@ -481,38 +472,31 @@ func (m *Manager) Subscribe(id string) chan BluetoothState {
} }
func (m *Manager) Unsubscribe(id string) { func (m *Manager) Unsubscribe(id string) {
if val, ok := m.subscribers.LoadAndDelete(id); ok { if ch, ok := m.subscribers.LoadAndDelete(id); ok {
close(val.(chan BluetoothState)) close(ch)
} }
} }
func (m *Manager) SubscribePairing(id string) chan PairingPrompt { func (m *Manager) SubscribePairing(id string) chan PairingPrompt {
ch := make(chan PairingPrompt, 16) ch := make(chan PairingPrompt, 16)
m.pairingSubMutex.Lock() m.pairingSubscribers.Store(id, ch)
m.pairingSubscribers[id] = ch
m.pairingSubMutex.Unlock()
return ch return ch
} }
func (m *Manager) UnsubscribePairing(id string) { func (m *Manager) UnsubscribePairing(id string) {
m.pairingSubMutex.Lock() if ch, ok := m.pairingSubscribers.LoadAndDelete(id); ok {
if ch, ok := m.pairingSubscribers[id]; ok {
close(ch) close(ch)
delete(m.pairingSubscribers, id)
} }
m.pairingSubMutex.Unlock()
} }
func (m *Manager) broadcastPairingPrompt(prompt PairingPrompt) { func (m *Manager) broadcastPairingPrompt(prompt PairingPrompt) {
m.pairingSubMutex.RLock() m.pairingSubscribers.Range(func(key string, ch chan PairingPrompt) bool {
defer m.pairingSubMutex.RUnlock()
for _, ch := range m.pairingSubscribers {
select { select {
case ch <- prompt: case ch <- prompt:
default: default:
} }
} return true
})
} }
func (m *Manager) SubmitPairing(token string, secrets map[string]string, accept bool) error { func (m *Manager) SubmitPairing(token string, secrets map[string]string, accept bool) error {
@@ -553,17 +537,13 @@ func (m *Manager) SetPowered(powered bool) error {
} }
func (m *Manager) PairDevice(devicePath string) error { func (m *Manager) PairDevice(devicePath string) error {
m.pendingPairingsMux.Lock() m.pendingPairings.Store(devicePath, true)
m.pendingPairings[devicePath] = true
m.pendingPairingsMux.Unlock()
obj := m.dbusConn.Object(bluezService, dbus.ObjectPath(devicePath)) obj := m.dbusConn.Object(bluezService, dbus.ObjectPath(devicePath))
err := obj.Call(device1Iface+".Pair", 0).Err err := obj.Call(device1Iface+".Pair", 0).Err
if err != nil { if err != nil {
m.pendingPairingsMux.Lock() m.pendingPairings.Delete(devicePath)
delete(m.pendingPairings, devicePath)
m.pendingPairingsMux.Unlock()
} }
return err return err
@@ -605,19 +585,17 @@ func (m *Manager) Close() {
m.agent.Close() m.agent.Close()
} }
m.subscribers.Range(func(key, value interface{}) bool { m.subscribers.Range(func(key string, ch chan BluetoothState) bool {
ch := value.(chan BluetoothState)
close(ch) close(ch)
m.subscribers.Delete(key) m.subscribers.Delete(key)
return true return true
}) })
m.pairingSubMutex.Lock() m.pairingSubscribers.Range(func(key string, ch chan PairingPrompt) bool {
for _, ch := range m.pairingSubscribers {
close(ch) close(ch)
} m.pairingSubscribers.Delete(key)
m.pairingSubscribers = make(map[string]chan PairingPrompt) return true
m.pairingSubMutex.Unlock() })
if m.dbusConn != nil { if m.dbusConn != nil {
m.dbusConn.Close() m.dbusConn.Close()

View File

@@ -3,22 +3,19 @@ package bluez
import ( import (
"context" "context"
"fmt" "fmt"
"sync"
"github.com/AvengeMedia/DankMaterialShell/core/internal/errdefs" "github.com/AvengeMedia/DankMaterialShell/core/internal/errdefs"
"github.com/AvengeMedia/DankMaterialShell/core/pkg/syncmap"
) )
type SubscriptionBroker struct { type SubscriptionBroker struct {
mu sync.RWMutex pending syncmap.Map[string, chan PromptReply]
pending map[string]chan PromptReply requests syncmap.Map[string, PromptRequest]
requests map[string]PromptRequest
broadcastPrompt func(PairingPrompt) broadcastPrompt func(PairingPrompt)
} }
func NewSubscriptionBroker(broadcastPrompt func(PairingPrompt)) PromptBroker { func NewSubscriptionBroker(broadcastPrompt func(PairingPrompt)) PromptBroker {
return &SubscriptionBroker{ return &SubscriptionBroker{
pending: make(map[string]chan PromptReply),
requests: make(map[string]PromptRequest),
broadcastPrompt: broadcastPrompt, broadcastPrompt: broadcastPrompt,
} }
} }
@@ -30,10 +27,8 @@ func (b *SubscriptionBroker) Ask(ctx context.Context, req PromptRequest) (string
} }
replyChan := make(chan PromptReply, 1) replyChan := make(chan PromptReply, 1)
b.mu.Lock() b.pending.Store(token, replyChan)
b.pending[token] = replyChan b.requests.Store(token, req)
b.requests[token] = req
b.mu.Unlock()
if b.broadcastPrompt != nil { if b.broadcastPrompt != nil {
prompt := PairingPrompt{ prompt := PairingPrompt{
@@ -53,10 +48,7 @@ func (b *SubscriptionBroker) Ask(ctx context.Context, req PromptRequest) (string
} }
func (b *SubscriptionBroker) Wait(ctx context.Context, token string) (PromptReply, error) { func (b *SubscriptionBroker) Wait(ctx context.Context, token string) (PromptReply, error) {
b.mu.RLock() replyChan, exists := b.pending.Load(token)
replyChan, exists := b.pending[token]
b.mu.RUnlock()
if !exists { if !exists {
return PromptReply{}, fmt.Errorf("unknown token: %s", token) return PromptReply{}, fmt.Errorf("unknown token: %s", token)
} }
@@ -75,10 +67,7 @@ func (b *SubscriptionBroker) Wait(ctx context.Context, token string) (PromptRepl
} }
func (b *SubscriptionBroker) Resolve(token string, reply PromptReply) error { func (b *SubscriptionBroker) Resolve(token string, reply PromptReply) error {
b.mu.RLock() replyChan, exists := b.pending.Load(token)
replyChan, exists := b.pending[token]
b.mu.RUnlock()
if !exists { if !exists {
return fmt.Errorf("unknown or expired token: %s", token) return fmt.Errorf("unknown or expired token: %s", token)
} }
@@ -92,8 +81,6 @@ func (b *SubscriptionBroker) Resolve(token string, reply PromptReply) error {
} }
func (b *SubscriptionBroker) cleanup(token string) { func (b *SubscriptionBroker) cleanup(token string) {
b.mu.Lock() b.pending.Delete(token)
delete(b.pending, token) b.requests.Delete(token)
delete(b.requests, token)
b.mu.Unlock()
} }

View File

@@ -3,6 +3,7 @@ package bluez
import ( import (
"sync" "sync"
"github.com/AvengeMedia/DankMaterialShell/core/pkg/syncmap"
"github.com/godbus/dbus/v5" "github.com/godbus/dbus/v5"
) )
@@ -59,21 +60,19 @@ type PairingPrompt struct {
type Manager struct { type Manager struct {
state *BluetoothState state *BluetoothState
stateMutex sync.RWMutex stateMutex sync.RWMutex
subscribers sync.Map subscribers syncmap.Map[string, chan BluetoothState]
stopChan chan struct{} stopChan chan struct{}
dbusConn *dbus.Conn dbusConn *dbus.Conn
signals chan *dbus.Signal signals chan *dbus.Signal
sigWG sync.WaitGroup sigWG sync.WaitGroup
agent *BluezAgent agent *BluezAgent
promptBroker PromptBroker promptBroker PromptBroker
pairingSubscribers map[string]chan PairingPrompt pairingSubscribers syncmap.Map[string, chan PairingPrompt]
pairingSubMutex sync.RWMutex
dirty chan struct{} dirty chan struct{}
notifierWg sync.WaitGroup notifierWg sync.WaitGroup
lastNotifiedState *BluetoothState lastNotifiedState *BluetoothState
adapterPath dbus.ObjectPath adapterPath dbus.ObjectPath
pendingPairings map[string]bool pendingPairings syncmap.Map[string, bool]
pendingPairingsMux sync.Mutex
eventQueue chan func() eventQueue chan func()
eventWg sync.WaitGroup eventWg sync.WaitGroup
} }

View File

@@ -24,7 +24,6 @@ const (
func NewDDCBackend() (*DDCBackend, error) { func NewDDCBackend() (*DDCBackend, error) {
b := &DDCBackend{ b := &DDCBackend{
devices: make(map[string]*ddcDevice),
scanInterval: 30 * time.Second, scanInterval: 30 * time.Second,
debounceTimers: make(map[string]*time.Timer), debounceTimers: make(map[string]*time.Timer),
debouncePending: make(map[string]ddcPendingSet), debouncePending: make(map[string]ddcPendingSet),
@@ -53,10 +52,10 @@ func (b *DDCBackend) scanI2CDevices() error {
return nil return nil
} }
b.devicesMutex.Lock() b.devices.Range(func(key string, value *ddcDevice) bool {
defer b.devicesMutex.Unlock() b.devices.Delete(key)
return true
b.devices = make(map[string]*ddcDevice) })
for i := 0; i < 32; i++ { for i := 0; i < 32; i++ {
busPath := fmt.Sprintf("/dev/i2c-%d", i) busPath := fmt.Sprintf("/dev/i2c-%d", i)
@@ -64,7 +63,6 @@ func (b *DDCBackend) scanI2CDevices() error {
continue continue
} }
// Skip SMBus, GPU internal buses (e.g. AMDGPU SMU) to prevent GPU hangs
if isIgnorableI2CBus(i) { if isIgnorableI2CBus(i) {
log.Debugf("Skipping ignorable i2c-%d", i) log.Debugf("Skipping ignorable i2c-%d", i)
continue continue
@@ -77,7 +75,7 @@ func (b *DDCBackend) scanI2CDevices() error {
id := fmt.Sprintf("ddc:i2c-%d", i) id := fmt.Sprintf("ddc:i2c-%d", i)
dev.id = id dev.id = id
b.devices[id] = dev b.devices.Store(id, dev)
log.Debugf("found DDC device on i2c-%d", i) log.Debugf("found DDC device on i2c-%d", i)
} }
@@ -164,12 +162,9 @@ func (b *DDCBackend) GetDevices() ([]Device, error) {
log.Debugf("DDC scan error: %v", err) log.Debugf("DDC scan error: %v", err)
} }
b.devicesMutex.Lock() devices := make([]Device, 0)
defer b.devicesMutex.Unlock()
devices := make([]Device, 0, len(b.devices)) b.devices.Range(func(id string, dev *ddcDevice) bool {
for id, dev := range b.devices {
devices = append(devices, Device{ devices = append(devices, Device{
Class: ClassDDC, Class: ClassDDC,
ID: id, ID: id,
@@ -179,7 +174,8 @@ func (b *DDCBackend) GetDevices() ([]Device, error) {
CurrentPercent: dev.lastBrightness, CurrentPercent: dev.lastBrightness,
Backend: "ddc", Backend: "ddc",
}) })
} return true
})
return devices, nil return devices, nil
} }
@@ -189,9 +185,7 @@ func (b *DDCBackend) SetBrightness(id string, value int, exponential bool, callb
} }
func (b *DDCBackend) SetBrightnessWithExponent(id string, value int, exponential bool, exponent float64, callback func()) error { func (b *DDCBackend) SetBrightnessWithExponent(id string, value int, exponential bool, exponent float64, callback func()) error {
b.devicesMutex.RLock() _, ok := b.devices.Load(id)
_, ok := b.devices[id]
b.devicesMutex.RUnlock()
if !ok { if !ok {
return fmt.Errorf("device not found: %s", id) return fmt.Errorf("device not found: %s", id)
@@ -202,8 +196,6 @@ func (b *DDCBackend) SetBrightnessWithExponent(id string, value int, exponential
} }
b.debounceMutex.Lock() b.debounceMutex.Lock()
defer b.debounceMutex.Unlock()
b.debouncePending[id] = ddcPendingSet{ b.debouncePending[id] = ddcPendingSet{
percent: value, percent: value,
callback: callback, callback: callback,
@@ -234,14 +226,13 @@ func (b *DDCBackend) SetBrightnessWithExponent(id string, value int, exponential
} }
}) })
} }
b.debounceMutex.Unlock()
return nil return nil
} }
func (b *DDCBackend) setBrightnessImmediateWithExponent(id string, value int) error { func (b *DDCBackend) setBrightnessImmediateWithExponent(id string, value int) error {
b.devicesMutex.RLock() dev, ok := b.devices.Load(id)
dev, ok := b.devices[id]
b.devicesMutex.RUnlock()
if !ok { if !ok {
return fmt.Errorf("device not found: %s", id) return fmt.Errorf("device not found: %s", id)
@@ -266,9 +257,8 @@ func (b *DDCBackend) setBrightnessImmediateWithExponent(id string, value int) er
return fmt.Errorf("get current capability: %w", err) return fmt.Errorf("get current capability: %w", err)
} }
max = cap.max max = cap.max
b.devicesMutex.Lock()
dev.max = max dev.max = max
b.devicesMutex.Unlock() b.devices.Store(id, dev)
} }
if err := b.setVCPFeature(fd, VCP_BRIGHTNESS, value); err != nil { if err := b.setVCPFeature(fd, VCP_BRIGHTNESS, value); err != nil {
@@ -277,10 +267,9 @@ func (b *DDCBackend) setBrightnessImmediateWithExponent(id string, value int) er
log.Debugf("set %s to %d/%d", id, value, max) log.Debugf("set %s to %d/%d", id, value, max)
b.devicesMutex.Lock()
dev.max = max dev.max = max
dev.lastBrightness = value dev.lastBrightness = value
b.devicesMutex.Unlock() b.devices.Store(id, dev)
return nil return nil
} }

View File

@@ -360,8 +360,7 @@ func (m *Manager) broadcastDeviceUpdate(deviceID string) {
log.Debugf("Broadcasting device update: %s at %d%%", deviceID, targetDevice.CurrentPercent) log.Debugf("Broadcasting device update: %s at %d%%", deviceID, targetDevice.CurrentPercent)
m.updateSubscribers.Range(func(key, value interface{}) bool { m.updateSubscribers.Range(func(key string, ch chan DeviceUpdate) bool {
ch := value.(chan DeviceUpdate)
select { select {
case ch <- update: case ch <- update:
default: default:

View File

@@ -13,9 +13,8 @@ import (
func NewSysfsBackend() (*SysfsBackend, error) { func NewSysfsBackend() (*SysfsBackend, error) {
b := &SysfsBackend{ b := &SysfsBackend{
basePath: "/sys/class", basePath: "/sys/class",
classes: []string{"backlight", "leds"}, classes: []string{"backlight", "leds"},
deviceCache: make(map[string]*sysfsDevice),
} }
if err := b.scanDevices(); err != nil { if err := b.scanDevices(); err != nil {
@@ -26,9 +25,6 @@ func NewSysfsBackend() (*SysfsBackend, error) {
} }
func (b *SysfsBackend) scanDevices() error { func (b *SysfsBackend) scanDevices() error {
b.deviceCacheMutex.Lock()
defer b.deviceCacheMutex.Unlock()
for _, class := range b.classes { for _, class := range b.classes {
classPath := filepath.Join(b.basePath, class) classPath := filepath.Join(b.basePath, class)
entries, err := os.ReadDir(classPath) entries, err := os.ReadDir(classPath)
@@ -68,13 +64,13 @@ func (b *SysfsBackend) scanDevices() error {
} }
deviceID := fmt.Sprintf("%s:%s", class, entry.Name()) deviceID := fmt.Sprintf("%s:%s", class, entry.Name())
b.deviceCache[deviceID] = &sysfsDevice{ b.deviceCache.Store(deviceID, &sysfsDevice{
class: deviceClass, class: deviceClass,
id: deviceID, id: deviceID,
name: entry.Name(), name: entry.Name(),
maxBrightness: maxBrightness, maxBrightness: maxBrightness,
minValue: minValue, minValue: minValue,
} })
log.Debugf("found %s device: %s (max=%d)", class, entry.Name(), maxBrightness) log.Debugf("found %s device: %s (max=%d)", class, entry.Name(), maxBrightness)
} }
@@ -106,19 +102,16 @@ func shouldSuppressDevice(name string) bool {
} }
func (b *SysfsBackend) GetDevices() ([]Device, error) { func (b *SysfsBackend) GetDevices() ([]Device, error) {
b.deviceCacheMutex.RLock() devices := make([]Device, 0)
defer b.deviceCacheMutex.RUnlock()
devices := make([]Device, 0, len(b.deviceCache)) b.deviceCache.Range(func(key string, dev *sysfsDevice) bool {
for _, dev := range b.deviceCache {
if shouldSuppressDevice(dev.name) { if shouldSuppressDevice(dev.name) {
continue return true
} }
parts := strings.SplitN(dev.id, ":", 2) parts := strings.SplitN(dev.id, ":", 2)
if len(parts) != 2 { if len(parts) != 2 {
continue return true
} }
class := parts[0] class := parts[0]
@@ -130,13 +123,13 @@ func (b *SysfsBackend) GetDevices() ([]Device, error) {
brightnessData, err := os.ReadFile(brightnessPath) brightnessData, err := os.ReadFile(brightnessPath)
if err != nil { if err != nil {
log.Debugf("failed to read brightness for %s: %v", dev.id, err) log.Debugf("failed to read brightness for %s: %v", dev.id, err)
continue return true
} }
current, err := strconv.Atoi(strings.TrimSpace(string(brightnessData))) current, err := strconv.Atoi(strings.TrimSpace(string(brightnessData)))
if err != nil { if err != nil {
log.Debugf("failed to parse brightness for %s: %v", dev.id, err) log.Debugf("failed to parse brightness for %s: %v", dev.id, err)
continue return true
} }
percent := b.ValueToPercent(current, dev, false) percent := b.ValueToPercent(current, dev, false)
@@ -150,16 +143,14 @@ func (b *SysfsBackend) GetDevices() ([]Device, error) {
CurrentPercent: percent, CurrentPercent: percent,
Backend: "sysfs", Backend: "sysfs",
}) })
} return true
})
return devices, nil return devices, nil
} }
func (b *SysfsBackend) GetDevice(id string) (*sysfsDevice, error) { func (b *SysfsBackend) GetDevice(id string) (*sysfsDevice, error) {
b.deviceCacheMutex.RLock() dev, ok := b.deviceCache.Load(id)
defer b.deviceCacheMutex.RUnlock()
dev, ok := b.deviceCache[id]
if !ok { if !ok {
return nil, fmt.Errorf("device not found: %s", id) return nil, fmt.Errorf("device not found: %s", id)
} }

View File

@@ -31,9 +31,8 @@ func TestManager_SetBrightness_LogindSuccess(t *testing.T) {
mockLogind := NewLogindBackendWithConn(mockConn) mockLogind := NewLogindBackendWithConn(mockConn)
sysfs := &SysfsBackend{ sysfs := &SysfsBackend{
basePath: tmpDir, basePath: tmpDir,
classes: []string{"backlight"}, classes: []string{"backlight"},
deviceCache: make(map[string]*sysfsDevice),
} }
if err := sysfs.scanDevices(); err != nil { if err := sysfs.scanDevices(); err != nil {
@@ -103,9 +102,8 @@ func TestManager_SetBrightness_LogindFailsFallbackToSysfs(t *testing.T) {
mockLogind := NewLogindBackendWithConn(mockConn) mockLogind := NewLogindBackendWithConn(mockConn)
sysfs := &SysfsBackend{ sysfs := &SysfsBackend{
basePath: tmpDir, basePath: tmpDir,
classes: []string{"backlight"}, classes: []string{"backlight"},
deviceCache: make(map[string]*sysfsDevice),
} }
if err := sysfs.scanDevices(); err != nil { if err := sysfs.scanDevices(); err != nil {
@@ -171,9 +169,8 @@ func TestManager_SetBrightness_NoLogind(t *testing.T) {
} }
sysfs := &SysfsBackend{ sysfs := &SysfsBackend{
basePath: tmpDir, basePath: tmpDir,
classes: []string{"backlight"}, classes: []string{"backlight"},
deviceCache: make(map[string]*sysfsDevice),
} }
if err := sysfs.scanDevices(); err != nil { if err := sysfs.scanDevices(); err != nil {
@@ -234,9 +231,8 @@ func TestManager_SetBrightness_LEDWithLogind(t *testing.T) {
mockLogind := NewLogindBackendWithConn(mockConn) mockLogind := NewLogindBackendWithConn(mockConn)
sysfs := &SysfsBackend{ sysfs := &SysfsBackend{
basePath: tmpDir, basePath: tmpDir,
classes: []string{"leds"}, classes: []string{"leds"},
deviceCache: make(map[string]*sysfsDevice),
} }
if err := sysfs.scanDevices(); err != nil { if err := sysfs.scanDevices(); err != nil {

View File

@@ -160,26 +160,21 @@ func TestSysfsBackend_ScanDevices(t *testing.T) {
} }
b := &SysfsBackend{ b := &SysfsBackend{
basePath: tmpDir, basePath: tmpDir,
classes: []string{"backlight", "leds"}, classes: []string{"backlight", "leds"},
deviceCache: make(map[string]*sysfsDevice),
} }
if err := b.scanDevices(); err != nil { if err := b.scanDevices(); err != nil {
t.Fatalf("scanDevices() error = %v", err) t.Fatalf("scanDevices() error = %v", err)
} }
if len(b.deviceCache) != 2 {
t.Errorf("expected 2 devices, got %d", len(b.deviceCache))
}
backlightID := "backlight:test_backlight" backlightID := "backlight:test_backlight"
if _, ok := b.deviceCache[backlightID]; !ok { if _, ok := b.deviceCache.Load(backlightID); !ok {
t.Errorf("backlight device not found") t.Errorf("backlight device not found")
} }
ledID := "leds:test_led" ledID := "leds:test_led"
if _, ok := b.deviceCache[ledID]; !ok { if _, ok := b.deviceCache.Load(ledID); !ok {
t.Errorf("LED device not found") t.Errorf("LED device not found")
} }
} }

View File

@@ -3,6 +3,8 @@ package brightness
import ( import (
"sync" "sync"
"time" "time"
"github.com/AvengeMedia/DankMaterialShell/core/pkg/syncmap"
) )
type DeviceClass string type DeviceClass string
@@ -51,8 +53,8 @@ type Manager struct {
stateMutex sync.RWMutex stateMutex sync.RWMutex
state State state State
subscribers sync.Map subscribers syncmap.Map[string, chan State]
updateSubscribers sync.Map updateSubscribers syncmap.Map[string, chan DeviceUpdate]
broadcastMutex sync.Mutex broadcastMutex sync.Mutex
broadcastTimer *time.Timer broadcastTimer *time.Timer
@@ -66,8 +68,7 @@ type SysfsBackend struct {
basePath string basePath string
classes []string classes []string
deviceCache map[string]*sysfsDevice deviceCache syncmap.Map[string, *sysfsDevice]
deviceCacheMutex sync.RWMutex
} }
type sysfsDevice struct { type sysfsDevice struct {
@@ -79,8 +80,7 @@ type sysfsDevice struct {
} }
type DDCBackend struct { type DDCBackend struct {
devices map[string]*ddcDevice devices syncmap.Map[string, *ddcDevice]
devicesMutex sync.RWMutex
scanMutex sync.Mutex scanMutex sync.Mutex
lastScan time.Time lastScan time.Time
@@ -129,7 +129,7 @@ func (m *Manager) Subscribe(id string) chan State {
func (m *Manager) Unsubscribe(id string) { func (m *Manager) Unsubscribe(id string) {
if val, ok := m.subscribers.LoadAndDelete(id); ok { if val, ok := m.subscribers.LoadAndDelete(id); ok {
close(val.(chan State)) close(val)
} }
@@ -143,7 +143,7 @@ func (m *Manager) SubscribeUpdates(id string) chan DeviceUpdate {
func (m *Manager) UnsubscribeUpdates(id string) { func (m *Manager) UnsubscribeUpdates(id string) {
if val, ok := m.updateSubscribers.LoadAndDelete(id); ok { if val, ok := m.updateSubscribers.LoadAndDelete(id); ok {
close(val.(chan DeviceUpdate)) close(val)
} }
} }
@@ -152,8 +152,7 @@ func (m *Manager) NotifySubscribers() {
state := m.state state := m.state
m.stateMutex.RUnlock() m.stateMutex.RUnlock()
m.subscribers.Range(func(key, value interface{}) bool { m.subscribers.Range(func(key string, ch chan State) bool {
ch := value.(chan State)
select { select {
case ch <- state: case ch <- state:
default: default:
@@ -171,14 +170,12 @@ func (m *Manager) GetState() State {
func (m *Manager) Close() { func (m *Manager) Close() {
close(m.stopChan) close(m.stopChan)
m.subscribers.Range(func(key, value interface{}) bool { m.subscribers.Range(func(key string, ch chan State) bool {
ch := value.(chan State)
close(ch) close(ch)
m.subscribers.Delete(key) m.subscribers.Delete(key)
return true return true
}) })
m.updateSubscribers.Range(func(key, value interface{}) bool { m.updateSubscribers.Range(func(key string, ch chan DeviceUpdate) bool {
ch := value.(chan DeviceUpdate)
close(ch) close(ch)
m.updateSubscribers.Delete(key) m.updateSubscribers.Delete(key)
return true return true

View File

@@ -148,8 +148,7 @@ func (m *Manager) notifier() {
continue continue
} }
m.subscribers.Range(func(key, value interface{}) bool { m.subscribers.Range(func(key string, ch chan CUPSState) bool {
ch := value.(chan CUPSState)
select { select {
case ch <- currentState: case ch <- currentState:
default: default:
@@ -193,7 +192,7 @@ func (m *Manager) Subscribe(id string) chan CUPSState {
ch := make(chan CUPSState, 64) ch := make(chan CUPSState, 64)
wasEmpty := true wasEmpty := true
m.subscribers.Range(func(key, value interface{}) bool { m.subscribers.Range(func(key string, ch chan CUPSState) bool {
wasEmpty = false wasEmpty = false
return false return false
}) })
@@ -214,11 +213,11 @@ func (m *Manager) Subscribe(id string) chan CUPSState {
func (m *Manager) Unsubscribe(id string) { func (m *Manager) Unsubscribe(id string) {
if val, ok := m.subscribers.LoadAndDelete(id); ok { if val, ok := m.subscribers.LoadAndDelete(id); ok {
close(val.(chan CUPSState)) close(val)
} }
isEmpty := true isEmpty := true
m.subscribers.Range(func(key, value interface{}) bool { m.subscribers.Range(func(key string, ch chan CUPSState) bool {
isEmpty = false isEmpty = false
return false return false
}) })
@@ -239,8 +238,7 @@ func (m *Manager) Close() {
m.eventWG.Wait() m.eventWG.Wait()
m.notifierWg.Wait() m.notifierWg.Wait()
m.subscribers.Range(func(key, value interface{}) bool { m.subscribers.Range(func(key string, ch chan CUPSState) bool {
ch := value.(chan CUPSState)
close(ch) close(ch)
m.subscribers.Delete(key) m.subscribers.Delete(key)
return true return true

View File

@@ -60,7 +60,7 @@ func TestManager_Subscribe(t *testing.T) {
assert.NotNil(t, ch) assert.NotNil(t, ch)
count := 0 count := 0
m.subscribers.Range(func(key, value interface{}) bool { m.subscribers.Range(func(key string, ch chan CUPSState) bool {
count++ count++
return true return true
}) })
@@ -68,7 +68,7 @@ func TestManager_Subscribe(t *testing.T) {
m.Unsubscribe("test-client") m.Unsubscribe("test-client")
count = 0 count = 0
m.subscribers.Range(func(key, value interface{}) bool { m.subscribers.Range(func(key string, ch chan CUPSState) bool {
count++ count++
return true return true
}) })
@@ -101,7 +101,7 @@ func TestManager_Close(t *testing.T) {
m.Close() m.Close()
count := 0 count := 0
m.subscribers.Range(func(key, value interface{}) bool { m.subscribers.Range(func(key string, ch chan CUPSState) bool {
count++ count++
return true return true
}) })

View File

@@ -6,6 +6,7 @@ import (
"time" "time"
"github.com/AvengeMedia/DankMaterialShell/core/pkg/ipp" "github.com/AvengeMedia/DankMaterialShell/core/pkg/ipp"
"github.com/AvengeMedia/DankMaterialShell/core/pkg/syncmap"
) )
type CUPSState struct { type CUPSState struct {
@@ -39,7 +40,7 @@ type Manager struct {
client CUPSClientInterface client CUPSClientInterface
subscription SubscriptionManagerInterface subscription SubscriptionManagerInterface
stateMutex sync.RWMutex stateMutex sync.RWMutex
subscribers sync.Map subscribers syncmap.Map[string, chan CUPSState]
stopChan chan struct{} stopChan chan struct{}
eventWG sync.WaitGroup eventWG sync.WaitGroup
dirty chan struct{} dirty chan struct{}

View File

@@ -14,7 +14,6 @@ func NewManager(display *wlclient.Display) (*Manager, error) {
m := &Manager{ m := &Manager{
display: display, display: display,
ctx: display.Context(), ctx: display.Context(),
outputs: make(map[uint32]*outputState),
cmdq: make(chan cmd, 128), cmdq: make(chan cmd, 128),
outputSetupReq: make(chan uint32, 16), outputSetupReq: make(chan uint32, 16),
stopChan: make(chan struct{}), stopChan: make(chan struct{}),
@@ -56,10 +55,7 @@ func (m *Manager) waylandActor() {
case c := <-m.cmdq: case c := <-m.cmdq:
c.fn() c.fn()
case outputID := <-m.outputSetupReq: case outputID := <-m.outputSetupReq:
m.outputsMutex.RLock() out, exists := m.outputs.Load(outputID)
out, exists := m.outputs[outputID]
m.outputsMutex.RUnlock()
if !exists { if !exists {
log.Warnf("DWL: Output %d no longer exists, skipping setup", outputID) log.Warnf("DWL: Output %d no longer exists, skipping setup", outputID)
continue continue
@@ -156,9 +152,7 @@ func (m *Manager) setupRegistry() error {
outputs = append(outputs, output) outputs = append(outputs, output)
outputRegNames[outputID] = e.Name outputRegNames[outputID] = e.Name
m.outputsMutex.Lock() m.outputs.Store(outputID, outState)
m.outputs[outputID] = outState
m.outputsMutex.Unlock()
if m.manager != nil { if m.manager != nil {
select { select {
@@ -176,17 +170,16 @@ func (m *Manager) setupRegistry() error {
registry.SetGlobalRemoveHandler(func(e wlclient.RegistryGlobalRemoveEvent) { registry.SetGlobalRemoveHandler(func(e wlclient.RegistryGlobalRemoveEvent) {
m.post(func() { m.post(func() {
m.outputsMutex.Lock()
var outToRelease *outputState var outToRelease *outputState
for id, out := range m.outputs { m.outputs.Range(func(id uint32, out *outputState) bool {
if out.registryName == e.Name { if out.registryName == e.Name {
log.Infof("DWL: Output %d removed", id) log.Infof("DWL: Output %d removed", id)
outToRelease = out outToRelease = out
delete(m.outputs, id) m.outputs.Delete(id)
break return false
} }
} return true
m.outputsMutex.Unlock() })
if outToRelease != nil { if outToRelease != nil {
if ipcOut, ok := outToRelease.ipcOutput.(*dwl_ipc.ZdwlIpcOutputV2); ok && ipcOut != nil { if ipcOut, ok := outToRelease.ipcOutput.(*dwl_ipc.ZdwlIpcOutputV2); ok && ipcOut != nil {
@@ -236,14 +229,11 @@ func (m *Manager) setupOutput(manager *dwl_ipc.ZdwlIpcManagerV2, output *wlclien
return fmt.Errorf("failed to get dwl output: %w", err) return fmt.Errorf("failed to get dwl output: %w", err)
} }
m.outputsMutex.Lock() outState, exists := m.outputs.Load(output.ID())
outState, exists := m.outputs[output.ID()]
if !exists { if !exists {
m.outputsMutex.Unlock()
return fmt.Errorf("output state not found for id %d", output.ID()) return fmt.Errorf("output state not found for id %d", output.ID())
} }
outState.ipcOutput = ipcOutput outState.ipcOutput = ipcOutput
m.outputsMutex.Unlock()
ipcOutput.SetActiveHandler(func(e dwl_ipc.ZdwlIpcOutputV2ActiveEvent) { ipcOutput.SetActiveHandler(func(e dwl_ipc.ZdwlIpcOutputV2ActiveEvent) {
outState.active = e.Active outState.active = e.Active
@@ -300,11 +290,10 @@ func (m *Manager) setupOutput(manager *dwl_ipc.ZdwlIpcManagerV2, output *wlclien
} }
func (m *Manager) updateState() { func (m *Manager) updateState() {
m.outputsMutex.RLock()
outputs := make(map[string]*OutputState) outputs := make(map[string]*OutputState)
activeOutput := "" activeOutput := ""
for _, out := range m.outputs { m.outputs.Range(func(key uint32, out *outputState) bool {
name := out.name name := out.name
if name == "" { if name == "" {
name = fmt.Sprintf("output-%d", out.id) name = fmt.Sprintf("output-%d", out.id)
@@ -326,8 +315,8 @@ func (m *Manager) updateState() {
if out.active != 0 { if out.active != 0 {
activeOutput = name activeOutput = name
} }
} return true
m.outputsMutex.RUnlock() })
newState := State{ newState := State{
Outputs: outputs, Outputs: outputs,
@@ -373,8 +362,7 @@ func (m *Manager) notifier() {
continue continue
} }
m.subscribers.Range(func(key, value interface{}) bool { m.subscribers.Range(func(key string, ch chan State) bool {
ch := value.(chan State)
select { select {
case ch <- currentState: case ch <- currentState:
default: default:
@@ -399,11 +387,9 @@ func (m *Manager) ensureOutputSetup(out *outputState) error {
} }
func (m *Manager) SetTags(outputName string, tagmask uint32, toggleTagset uint32) error { func (m *Manager) SetTags(outputName string, tagmask uint32, toggleTagset uint32) error {
m.outputsMutex.RLock() availableOutputs := make([]string, 0)
availableOutputs := make([]string, 0, len(m.outputs))
var targetOut *outputState var targetOut *outputState
for _, out := range m.outputs { m.outputs.Range(func(key uint32, out *outputState) bool {
name := out.name name := out.name
if name == "" { if name == "" {
name = fmt.Sprintf("output-%d", out.id) name = fmt.Sprintf("output-%d", out.id)
@@ -411,10 +397,10 @@ func (m *Manager) SetTags(outputName string, tagmask uint32, toggleTagset uint32
availableOutputs = append(availableOutputs, name) availableOutputs = append(availableOutputs, name)
if name == outputName { if name == outputName {
targetOut = out targetOut = out
break return false
} }
} return true
m.outputsMutex.RUnlock() })
if targetOut == nil { if targetOut == nil {
return fmt.Errorf("output not found: %s (available: %v)", outputName, availableOutputs) return fmt.Errorf("output not found: %s (available: %v)", outputName, availableOutputs)
@@ -436,20 +422,18 @@ func (m *Manager) SetTags(outputName string, tagmask uint32, toggleTagset uint32
} }
func (m *Manager) SetClientTags(outputName string, andTags uint32, xorTags uint32) error { func (m *Manager) SetClientTags(outputName string, andTags uint32, xorTags uint32) error {
m.outputsMutex.RLock()
var targetOut *outputState var targetOut *outputState
for _, out := range m.outputs { m.outputs.Range(func(key uint32, out *outputState) bool {
name := out.name name := out.name
if name == "" { if name == "" {
name = fmt.Sprintf("output-%d", out.id) name = fmt.Sprintf("output-%d", out.id)
} }
if name == outputName { if name == outputName {
targetOut = out targetOut = out
break return false
} }
} return true
m.outputsMutex.RUnlock() })
if targetOut == nil { if targetOut == nil {
return fmt.Errorf("output not found: %s", outputName) return fmt.Errorf("output not found: %s", outputName)
@@ -471,20 +455,18 @@ func (m *Manager) SetClientTags(outputName string, andTags uint32, xorTags uint3
} }
func (m *Manager) SetLayout(outputName string, index uint32) error { func (m *Manager) SetLayout(outputName string, index uint32) error {
m.outputsMutex.RLock()
var targetOut *outputState var targetOut *outputState
for _, out := range m.outputs { m.outputs.Range(func(key uint32, out *outputState) bool {
name := out.name name := out.name
if name == "" { if name == "" {
name = fmt.Sprintf("output-%d", out.id) name = fmt.Sprintf("output-%d", out.id)
} }
if name == outputName { if name == outputName {
targetOut = out targetOut = out
break return false
} }
} return true
m.outputsMutex.RUnlock() })
if targetOut == nil { if targetOut == nil {
return fmt.Errorf("output not found: %s", outputName) return fmt.Errorf("output not found: %s", outputName)
@@ -510,21 +492,19 @@ func (m *Manager) Close() {
m.wg.Wait() m.wg.Wait()
m.notifierWg.Wait() m.notifierWg.Wait()
m.subscribers.Range(func(key, value interface{}) bool { m.subscribers.Range(func(key string, ch chan State) bool {
ch := value.(chan State)
close(ch) close(ch)
m.subscribers.Delete(key) m.subscribers.Delete(key)
return true return true
}) })
m.outputsMutex.Lock() m.outputs.Range(func(key uint32, out *outputState) bool {
for _, out := range m.outputs {
if ipcOut, ok := out.ipcOutput.(*dwl_ipc.ZdwlIpcOutputV2); ok { if ipcOut, ok := out.ipcOutput.(*dwl_ipc.ZdwlIpcOutputV2); ok {
ipcOut.Release() ipcOut.Release()
} }
} m.outputs.Delete(key)
m.outputs = make(map[uint32]*outputState) return true
m.outputsMutex.Unlock() })
if mgr, ok := m.manager.(*dwl_ipc.ZdwlIpcManagerV2); ok { if mgr, ok := m.manager.(*dwl_ipc.ZdwlIpcManagerV2); ok {
mgr.Release() mgr.Release()

View File

@@ -4,6 +4,7 @@ import (
"sync" "sync"
wlclient "github.com/AvengeMedia/DankMaterialShell/core/pkg/go-wayland/wayland/client" wlclient "github.com/AvengeMedia/DankMaterialShell/core/pkg/go-wayland/wayland/client"
"github.com/AvengeMedia/DankMaterialShell/core/pkg/syncmap"
) )
type TagState struct { type TagState struct {
@@ -40,8 +41,7 @@ type Manager struct {
registry *wlclient.Registry registry *wlclient.Registry
manager interface{} manager interface{}
outputs map[uint32]*outputState outputs syncmap.Map[uint32, *outputState]
outputsMutex sync.RWMutex
tagCount uint32 tagCount uint32
layouts []string layouts []string
@@ -52,7 +52,7 @@ type Manager struct {
stopChan chan struct{} stopChan chan struct{}
wg sync.WaitGroup wg sync.WaitGroup
subscribers sync.Map subscribers syncmap.Map[string, chan State]
dirty chan struct{} dirty chan struct{}
notifierWg sync.WaitGroup notifierWg sync.WaitGroup
lastNotified *State lastNotified *State
@@ -98,12 +98,9 @@ func (m *Manager) Subscribe(id string) chan State {
} }
func (m *Manager) Unsubscribe(id string) { func (m *Manager) Unsubscribe(id string) {
if val, ok := m.subscribers.LoadAndDelete(id); ok { if val, ok := m.subscribers.LoadAndDelete(id); ok {
close(val.(chan State)) close(val)
} }
} }
func (m *Manager) notifySubscribers() { func (m *Manager) notifySubscribers() {

View File

@@ -9,6 +9,7 @@ import (
"time" "time"
"github.com/AvengeMedia/DankMaterialShell/core/internal/log" "github.com/AvengeMedia/DankMaterialShell/core/internal/log"
"github.com/AvengeMedia/DankMaterialShell/core/pkg/syncmap"
"github.com/fsnotify/fsnotify" "github.com/fsnotify/fsnotify"
evdev "github.com/holoplot/go-evdev" evdev "github.com/holoplot/go-evdev"
) )
@@ -35,7 +36,7 @@ type Manager struct {
monitoredPaths map[string]bool monitoredPaths map[string]bool
state State state State
stateMutex sync.RWMutex stateMutex sync.RWMutex
subscribers sync.Map subscribers syncmap.Map[string, chan State]
closeChan chan struct{} closeChan chan struct{}
closeOnce sync.Once closeOnce sync.Once
watcher *fsnotify.Watcher watcher *fsnotify.Watcher
@@ -338,13 +339,12 @@ func (m *Manager) Subscribe(id string) chan State {
func (m *Manager) Unsubscribe(id string) { func (m *Manager) Unsubscribe(id string) {
if val, ok := m.subscribers.LoadAndDelete(id); ok { if val, ok := m.subscribers.LoadAndDelete(id); ok {
close(val.(chan State)) close(val)
} }
} }
func (m *Manager) notifySubscribers(state State) { func (m *Manager) notifySubscribers(state State) {
m.subscribers.Range(func(key, value interface{}) bool { m.subscribers.Range(func(key string, ch chan State) bool {
ch := value.(chan State)
select { select {
case ch <- state: case ch <- state:
default: default:
@@ -372,8 +372,7 @@ func (m *Manager) Close() {
} }
m.devicesMutex.Unlock() m.devicesMutex.Unlock()
m.subscribers.Range(func(key, value interface{}) bool { m.subscribers.Range(func(key string, ch chan State) bool {
ch := value.(chan State)
close(ch) close(ch)
m.subscribers.Delete(key) m.subscribers.Delete(key)
return true return true

View File

@@ -72,7 +72,7 @@ func TestManager_Subscribe(t *testing.T) {
ch := m.Subscribe("test-client") ch := m.Subscribe("test-client")
assert.NotNil(t, ch) assert.NotNil(t, ch)
count := 0 count := 0
m.subscribers.Range(func(key, value interface{}) bool { m.subscribers.Range(func(key string, ch chan State) bool {
count++ count++
return true return true
}) })
@@ -92,7 +92,7 @@ func TestManager_Unsubscribe(t *testing.T) {
ch := m.Subscribe("test-client") ch := m.Subscribe("test-client")
count := 0 count := 0
m.subscribers.Range(func(key, value interface{}) bool { m.subscribers.Range(func(key string, ch chan State) bool {
count++ count++
return true return true
}) })
@@ -100,7 +100,7 @@ func TestManager_Unsubscribe(t *testing.T) {
m.Unsubscribe("test-client") m.Unsubscribe("test-client")
count = 0 count = 0
m.subscribers.Range(func(key, value interface{}) bool { m.subscribers.Range(func(key string, ch chan State) bool {
count++ count++
return true return true
}) })
@@ -180,7 +180,7 @@ func TestManager_Close(t *testing.T) {
} }
count := 0 count := 0
m.subscribers.Range(func(key, value interface{}) bool { m.subscribers.Range(func(key string, ch chan State) bool {
count++ count++
return true return true
}) })

View File

@@ -11,14 +11,10 @@ import (
func NewManager(display *wlclient.Display) (*Manager, error) { func NewManager(display *wlclient.Display) (*Manager, error) {
m := &Manager{ m := &Manager{
display: display, display: display,
ctx: display.Context(), ctx: display.Context(),
outputs: make(map[uint32]*wlclient.Output), cmdq: make(chan cmd, 128),
outputNames: make(map[uint32]string), stopChan: make(chan struct{}),
groups: make(map[uint32]*workspaceGroupState),
workspaces: make(map[uint32]*workspaceState),
cmdq: make(chan cmd, 128),
stopChan: make(chan struct{}),
dirty: make(chan struct{}, 1), dirty: make(chan struct{}, 1),
} }
@@ -77,9 +73,7 @@ func (m *Manager) setupRegistry() error {
outputID := output.ID() outputID := output.ID()
output.SetNameHandler(func(ev wlclient.OutputNameEvent) { output.SetNameHandler(func(ev wlclient.OutputNameEvent) {
m.outputsMutex.Lock() m.outputNames.Store(outputID, ev.Name)
m.outputNames[outputID] = ev.Name
m.outputsMutex.Unlock()
log.Debugf("ExtWorkspace: Output %d (%s) name received", outputID, ev.Name) log.Debugf("ExtWorkspace: Output %d (%s) name received", outputID, ev.Name)
}) })
} }
@@ -139,9 +133,7 @@ func (m *Manager) handleWorkspaceGroup(e ext_workspace.ExtWorkspaceManagerV1Work
workspaceIDs: make([]uint32, 0), workspaceIDs: make([]uint32, 0),
} }
m.groupsMutex.Lock() m.groups.Store(groupID, group)
m.groups[groupID] = group
m.groupsMutex.Unlock()
handle.SetCapabilitiesHandler(func(e ext_workspace.ExtWorkspaceGroupHandleV1CapabilitiesEvent) { handle.SetCapabilitiesHandler(func(e ext_workspace.ExtWorkspaceGroupHandleV1CapabilitiesEvent) {
log.Debugf("ExtWorkspace: Group %d capabilities: %d", groupID, e.Capabilities) log.Debugf("ExtWorkspace: Group %d capabilities: %d", groupID, e.Capabilities)
@@ -171,11 +163,9 @@ func (m *Manager) handleWorkspaceGroup(e ext_workspace.ExtWorkspaceManagerV1Work
log.Debugf("ExtWorkspace: Group %d workspace enter (workspace=%d)", groupID, workspaceID) log.Debugf("ExtWorkspace: Group %d workspace enter (workspace=%d)", groupID, workspaceID)
m.post(func() { m.post(func() {
m.workspacesMutex.Lock() if ws, ok := m.workspaces.Load(workspaceID); ok {
if ws, exists := m.workspaces[workspaceID]; exists {
ws.groupID = groupID ws.groupID = groupID
} }
m.workspacesMutex.Unlock()
group.workspaceIDs = append(group.workspaceIDs, workspaceID) group.workspaceIDs = append(group.workspaceIDs, workspaceID)
m.updateState() m.updateState()
@@ -187,11 +177,9 @@ func (m *Manager) handleWorkspaceGroup(e ext_workspace.ExtWorkspaceManagerV1Work
log.Debugf("ExtWorkspace: Group %d workspace leave (workspace=%d)", groupID, workspaceID) log.Debugf("ExtWorkspace: Group %d workspace leave (workspace=%d)", groupID, workspaceID)
m.post(func() { m.post(func() {
m.workspacesMutex.Lock() if ws, ok := m.workspaces.Load(workspaceID); ok {
if ws, exists := m.workspaces[workspaceID]; exists {
ws.groupID = 0 ws.groupID = 0
} }
m.workspacesMutex.Unlock()
for i, id := range group.workspaceIDs { for i, id := range group.workspaceIDs {
if id == workspaceID { if id == workspaceID {
@@ -209,9 +197,7 @@ func (m *Manager) handleWorkspaceGroup(e ext_workspace.ExtWorkspaceManagerV1Work
m.post(func() { m.post(func() {
group.removed = true group.removed = true
m.groupsMutex.Lock() m.groups.Delete(groupID)
delete(m.groups, groupID)
m.groupsMutex.Unlock()
m.wlMutex.Lock() m.wlMutex.Lock()
handle.Destroy() handle.Destroy()
@@ -234,9 +220,7 @@ func (m *Manager) handleWorkspace(e ext_workspace.ExtWorkspaceManagerV1Workspace
coordinates: make([]uint32, 0), coordinates: make([]uint32, 0),
} }
m.workspacesMutex.Lock() m.workspaces.Store(workspaceID, ws)
m.workspaces[workspaceID] = ws
m.workspacesMutex.Unlock()
handle.SetIdHandler(func(e ext_workspace.ExtWorkspaceHandleV1IdEvent) { handle.SetIdHandler(func(e ext_workspace.ExtWorkspaceHandleV1IdEvent) {
log.Debugf("ExtWorkspace: Workspace %d id: %s", workspaceID, e.Id) log.Debugf("ExtWorkspace: Workspace %d id: %s", workspaceID, e.Id)
@@ -290,9 +274,7 @@ func (m *Manager) handleWorkspace(e ext_workspace.ExtWorkspaceManagerV1Workspace
m.post(func() { m.post(func() {
ws.removed = true ws.removed = true
m.workspacesMutex.Lock() m.workspaces.Delete(workspaceID)
delete(m.workspaces, workspaceID)
m.workspacesMutex.Unlock()
m.wlMutex.Lock() m.wlMutex.Lock()
handle.Destroy() handle.Destroy()
@@ -304,23 +286,21 @@ func (m *Manager) handleWorkspace(e ext_workspace.ExtWorkspaceManagerV1Workspace
} }
func (m *Manager) updateState() { func (m *Manager) updateState() {
m.groupsMutex.RLock()
m.workspacesMutex.RLock()
groups := make([]*WorkspaceGroup, 0) groups := make([]*WorkspaceGroup, 0)
for _, group := range m.groups { m.groups.Range(func(key uint32, group *workspaceGroupState) bool {
if group.removed { if group.removed {
continue return true
} }
outputs := make([]string, 0) outputs := make([]string, 0)
for outputID := range group.outputIDs { for outputID := range group.outputIDs {
m.outputsMutex.RLock() if name, ok := m.outputNames.Load(outputID); ok {
name := m.outputNames[outputID] if name != "" {
m.outputsMutex.RUnlock() outputs = append(outputs, name)
if name != "" { } else {
outputs = append(outputs, name) outputs = append(outputs, fmt.Sprintf("output-%d", outputID))
}
} else { } else {
outputs = append(outputs, fmt.Sprintf("output-%d", outputID)) outputs = append(outputs, fmt.Sprintf("output-%d", outputID))
} }
@@ -328,8 +308,11 @@ func (m *Manager) updateState() {
workspaces := make([]*Workspace, 0) workspaces := make([]*Workspace, 0)
for _, wsID := range group.workspaceIDs { for _, wsID := range group.workspaceIDs {
ws, exists := m.workspaces[wsID] ws, exists := m.workspaces.Load(wsID)
if !exists || ws.removed { if !exists {
continue
}
if ws.removed {
continue continue
} }
@@ -351,10 +334,8 @@ func (m *Manager) updateState() {
Workspaces: workspaces, Workspaces: workspaces,
} }
groups = append(groups, groupState) groups = append(groups, groupState)
} return true
})
m.workspacesMutex.RUnlock()
m.groupsMutex.RUnlock()
newState := State{ newState := State{
Groups: groups, Groups: groups,
@@ -397,8 +378,7 @@ func (m *Manager) notifier() {
continue continue
} }
m.subscribers.Range(func(key, value interface{}) bool { m.subscribers.Range(func(key string, ch chan State) bool {
ch := value.(chan State)
select { select {
case ch <- currentState: case ch <- currentState:
default: default:
@@ -418,9 +398,6 @@ func (m *Manager) ActivateWorkspace(groupID, workspaceID string) error {
errChan := make(chan error, 1) errChan := make(chan error, 1)
m.post(func() { m.post(func() {
m.workspacesMutex.RLock()
defer m.workspacesMutex.RUnlock()
var targetGroupID uint32 var targetGroupID uint32
if groupID != "" { if groupID != "" {
var parsedID uint32 var parsedID uint32
@@ -429,9 +406,10 @@ func (m *Manager) ActivateWorkspace(groupID, workspaceID string) error {
} }
} }
for _, ws := range m.workspaces { var found bool
m.workspaces.Range(func(key uint32, ws *workspaceState) bool {
if targetGroupID != 0 && ws.groupID != targetGroupID { if targetGroupID != 0 && ws.groupID != targetGroupID {
continue return true
} }
if ws.workspaceID == workspaceID || ws.name == workspaceID { if ws.workspaceID == workspaceID || ws.name == workspaceID {
m.wlMutex.Lock() m.wlMutex.Lock()
@@ -441,11 +419,15 @@ func (m *Manager) ActivateWorkspace(groupID, workspaceID string) error {
} }
m.wlMutex.Unlock() m.wlMutex.Unlock()
errChan <- err errChan <- err
return found = true
return false
} }
} return true
})
errChan <- fmt.Errorf("workspace not found: %s in group %s", workspaceID, groupID) if !found {
errChan <- fmt.Errorf("workspace not found: %s in group %s", workspaceID, groupID)
}
}) })
return <-errChan return <-errChan
@@ -455,9 +437,6 @@ func (m *Manager) DeactivateWorkspace(groupID, workspaceID string) error {
errChan := make(chan error, 1) errChan := make(chan error, 1)
m.post(func() { m.post(func() {
m.workspacesMutex.RLock()
defer m.workspacesMutex.RUnlock()
var targetGroupID uint32 var targetGroupID uint32
if groupID != "" { if groupID != "" {
var parsedID uint32 var parsedID uint32
@@ -466,9 +445,10 @@ func (m *Manager) DeactivateWorkspace(groupID, workspaceID string) error {
} }
} }
for _, ws := range m.workspaces { var found bool
m.workspaces.Range(func(key uint32, ws *workspaceState) bool {
if targetGroupID != 0 && ws.groupID != targetGroupID { if targetGroupID != 0 && ws.groupID != targetGroupID {
continue return true
} }
if ws.workspaceID == workspaceID || ws.name == workspaceID { if ws.workspaceID == workspaceID || ws.name == workspaceID {
m.wlMutex.Lock() m.wlMutex.Lock()
@@ -478,11 +458,15 @@ func (m *Manager) DeactivateWorkspace(groupID, workspaceID string) error {
} }
m.wlMutex.Unlock() m.wlMutex.Unlock()
errChan <- err errChan <- err
return found = true
return false
} }
} return true
})
errChan <- fmt.Errorf("workspace not found: %s in group %s", workspaceID, groupID) if !found {
errChan <- fmt.Errorf("workspace not found: %s in group %s", workspaceID, groupID)
}
}) })
return <-errChan return <-errChan
@@ -492,9 +476,6 @@ func (m *Manager) RemoveWorkspace(groupID, workspaceID string) error {
errChan := make(chan error, 1) errChan := make(chan error, 1)
m.post(func() { m.post(func() {
m.workspacesMutex.RLock()
defer m.workspacesMutex.RUnlock()
var targetGroupID uint32 var targetGroupID uint32
if groupID != "" { if groupID != "" {
var parsedID uint32 var parsedID uint32
@@ -503,9 +484,10 @@ func (m *Manager) RemoveWorkspace(groupID, workspaceID string) error {
} }
} }
for _, ws := range m.workspaces { var found bool
m.workspaces.Range(func(key uint32, ws *workspaceState) bool {
if targetGroupID != 0 && ws.groupID != targetGroupID { if targetGroupID != 0 && ws.groupID != targetGroupID {
continue return true
} }
if ws.workspaceID == workspaceID || ws.name == workspaceID { if ws.workspaceID == workspaceID || ws.name == workspaceID {
m.wlMutex.Lock() m.wlMutex.Lock()
@@ -515,11 +497,15 @@ func (m *Manager) RemoveWorkspace(groupID, workspaceID string) error {
} }
m.wlMutex.Unlock() m.wlMutex.Unlock()
errChan <- err errChan <- err
return found = true
return false
} }
} return true
})
errChan <- fmt.Errorf("workspace not found: %s in group %s", workspaceID, groupID) if !found {
errChan <- fmt.Errorf("workspace not found: %s in group %s", workspaceID, groupID)
}
}) })
return <-errChan return <-errChan
@@ -529,10 +515,8 @@ func (m *Manager) CreateWorkspace(groupID, workspaceName string) error {
errChan := make(chan error, 1) errChan := make(chan error, 1)
m.post(func() { m.post(func() {
m.groupsMutex.RLock() var found bool
defer m.groupsMutex.RUnlock() m.groups.Range(func(key uint32, group *workspaceGroupState) bool {
for _, group := range m.groups {
if fmt.Sprintf("group-%d", group.id) == groupID { if fmt.Sprintf("group-%d", group.id) == groupID {
m.wlMutex.Lock() m.wlMutex.Lock()
err := group.handle.CreateWorkspace(workspaceName) err := group.handle.CreateWorkspace(workspaceName)
@@ -541,11 +525,15 @@ func (m *Manager) CreateWorkspace(groupID, workspaceName string) error {
} }
m.wlMutex.Unlock() m.wlMutex.Unlock()
errChan <- err errChan <- err
return found = true
return false
} }
} return true
})
errChan <- fmt.Errorf("workspace group not found: %s", groupID) if !found {
errChan <- fmt.Errorf("workspace group not found: %s", groupID)
}
}) })
return <-errChan return <-errChan
@@ -556,30 +544,27 @@ func (m *Manager) Close() {
m.wg.Wait() m.wg.Wait()
m.notifierWg.Wait() m.notifierWg.Wait()
m.subscribers.Range(func(key, value interface{}) bool { m.subscribers.Range(func(key string, ch chan State) bool {
ch := value.(chan State)
close(ch) close(ch)
m.subscribers.Delete(key) m.subscribers.Delete(key)
return true return true
}) })
m.workspacesMutex.Lock() m.workspaces.Range(func(key uint32, ws *workspaceState) bool {
for _, ws := range m.workspaces {
if ws.handle != nil { if ws.handle != nil {
ws.handle.Destroy() ws.handle.Destroy()
} }
} m.workspaces.Delete(key)
m.workspaces = make(map[uint32]*workspaceState) return true
m.workspacesMutex.Unlock() })
m.groupsMutex.Lock() m.groups.Range(func(key uint32, group *workspaceGroupState) bool {
for _, group := range m.groups {
if group.handle != nil { if group.handle != nil {
group.handle.Destroy() group.handle.Destroy()
} }
} m.groups.Delete(key)
m.groups = make(map[uint32]*workspaceGroupState) return true
m.groupsMutex.Unlock() })
if m.manager != nil { if m.manager != nil {
m.manager.Stop() m.manager.Stop()

View File

@@ -5,6 +5,7 @@ import (
"github.com/AvengeMedia/DankMaterialShell/core/internal/proto/ext_workspace" "github.com/AvengeMedia/DankMaterialShell/core/internal/proto/ext_workspace"
wlclient "github.com/AvengeMedia/DankMaterialShell/core/pkg/go-wayland/wayland/client" wlclient "github.com/AvengeMedia/DankMaterialShell/core/pkg/go-wayland/wayland/client"
"github.com/AvengeMedia/DankMaterialShell/core/pkg/syncmap"
) )
type Workspace struct { type Workspace struct {
@@ -37,22 +38,18 @@ type Manager struct {
registry *wlclient.Registry registry *wlclient.Registry
manager *ext_workspace.ExtWorkspaceManagerV1 manager *ext_workspace.ExtWorkspaceManagerV1
outputsMutex sync.RWMutex outputNames syncmap.Map[uint32, string]
outputs map[uint32]*wlclient.Output
outputNames map[uint32]string
groupsMutex sync.RWMutex groups syncmap.Map[uint32, *workspaceGroupState]
groups map[uint32]*workspaceGroupState
workspacesMutex sync.RWMutex workspaces syncmap.Map[uint32, *workspaceState]
workspaces map[uint32]*workspaceState
wlMutex sync.Mutex wlMutex sync.Mutex
cmdq chan cmd cmdq chan cmd
stopChan chan struct{} stopChan chan struct{}
wg sync.WaitGroup wg sync.WaitGroup
subscribers sync.Map subscribers syncmap.Map[string, chan State]
dirty chan struct{} dirty chan struct{}
notifierWg sync.WaitGroup notifierWg sync.WaitGroup
lastNotified *State lastNotified *State
@@ -101,12 +98,9 @@ func (m *Manager) Subscribe(id string) chan State {
} }
func (m *Manager) Unsubscribe(id string) { func (m *Manager) Unsubscribe(id string) {
if ch, ok := m.subscribers.LoadAndDelete(id); ok {
if val, ok := m.subscribers.LoadAndDelete(id); ok { close(ch)
close(val.(chan State))
} }
} }
func (m *Manager) notifySubscribers() { func (m *Manager) notifySubscribers() {

View File

@@ -210,14 +210,13 @@ func (m *Manager) Subscribe(id string) chan FreedeskState {
func (m *Manager) Unsubscribe(id string) { func (m *Manager) Unsubscribe(id string) {
if val, ok := m.subscribers.LoadAndDelete(id); ok { if val, ok := m.subscribers.LoadAndDelete(id); ok {
close(val.(chan FreedeskState)) close(val)
} }
} }
func (m *Manager) NotifySubscribers() { func (m *Manager) NotifySubscribers() {
state := m.GetState() state := m.GetState()
m.subscribers.Range(func(key, value interface{}) bool { m.subscribers.Range(func(key string, ch chan FreedeskState) bool {
ch := value.(chan FreedeskState)
select { select {
case ch <- state: case ch <- state:
default: default:
@@ -227,8 +226,7 @@ func (m *Manager) NotifySubscribers() {
} }
func (m *Manager) Close() { func (m *Manager) Close() {
m.subscribers.Range(func(key, value interface{}) bool { m.subscribers.Range(func(key string, ch chan FreedeskState) bool {
ch := value.(chan FreedeskState)
close(ch) close(ch)
m.subscribers.Delete(key) m.subscribers.Delete(key)
return true return true

View File

@@ -3,6 +3,7 @@ package freedesktop
import ( import (
"sync" "sync"
"github.com/AvengeMedia/DankMaterialShell/core/pkg/syncmap"
"github.com/godbus/dbus/v5" "github.com/godbus/dbus/v5"
) )
@@ -41,5 +42,5 @@ type Manager struct {
accountsObj dbus.BusObject accountsObj dbus.BusObject
settingsObj dbus.BusObject settingsObj dbus.BusObject
currentUID uint64 currentUID uint64
subscribers sync.Map subscribers syncmap.Map[string, chan FreedeskState]
} }

View File

@@ -356,7 +356,7 @@ func (m *Manager) Subscribe(id string) chan SessionState {
func (m *Manager) Unsubscribe(id string) { func (m *Manager) Unsubscribe(id string) {
if val, ok := m.subscribers.LoadAndDelete(id); ok { if val, ok := m.subscribers.LoadAndDelete(id); ok {
close(val.(chan SessionState)) close(val)
} }
} }
@@ -389,8 +389,7 @@ func (m *Manager) notifier() {
continue continue
} }
m.subscribers.Range(func(key, value interface{}) bool { m.subscribers.Range(func(key string, ch chan SessionState) bool {
ch := value.(chan SessionState)
select { select {
case ch <- currentState: case ch <- currentState:
default: default:
@@ -572,8 +571,7 @@ func (m *Manager) Close() {
m.releaseSleepInhibitor() m.releaseSleepInhibitor()
m.subscribers.Range(func(key, value interface{}) bool { m.subscribers.Range(func(key string, ch chan SessionState) bool {
ch := value.(chan SessionState)
close(ch) close(ch)
m.subscribers.Delete(key) m.subscribers.Delete(key)
return true return true

View File

@@ -163,7 +163,7 @@ func TestManager_Close(t *testing.T) {
assert.False(t, ok2, "ch2 should be closed") assert.False(t, ok2, "ch2 should be closed")
count := 0 count := 0
manager.subscribers.Range(func(key, value interface{}) bool { manager.subscribers.Range(func(key string, ch chan SessionState) bool {
count++ count++
return true return true
}) })

View File

@@ -6,6 +6,7 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/AvengeMedia/DankMaterialShell/core/pkg/syncmap"
"github.com/godbus/dbus/v5" "github.com/godbus/dbus/v5"
) )
@@ -50,7 +51,7 @@ type SessionEvent struct {
type Manager struct { type Manager struct {
state *SessionState state *SessionState
stateMutex sync.RWMutex stateMutex sync.RWMutex
subscribers sync.Map subscribers syncmap.Map[string, chan SessionState]
stopChan chan struct{} stopChan chan struct{}
conn *dbus.Conn conn *dbus.Conn
sessionPath dbus.ObjectPath sessionPath dbus.ObjectPath

View File

@@ -247,7 +247,7 @@ func TestManager_Subscribe_Unsubscribe(t *testing.T) {
ch := manager.Subscribe("client1") ch := manager.Subscribe("client1")
assert.NotNil(t, ch) assert.NotNil(t, ch)
count := 0 count := 0
manager.subscribers.Range(func(key, value interface{}) bool { manager.subscribers.Range(func(key string, ch chan NetworkState) bool {
count++ count++
return true return true
}) })
@@ -257,7 +257,7 @@ func TestManager_Subscribe_Unsubscribe(t *testing.T) {
t.Run("unsubscribe removes channel", func(t *testing.T) { t.Run("unsubscribe removes channel", func(t *testing.T) {
manager.Unsubscribe("client1") manager.Unsubscribe("client1")
count := 0 count := 0
manager.subscribers.Range(func(key, value interface{}) bool { count++; return true }) manager.subscribers.Range(func(key string, ch chan NetworkState) bool { count++; return true })
assert.Equal(t, 0, count) assert.Equal(t, 0, count)
}) })

View File

@@ -68,10 +68,8 @@ func NewManager() (*Manager, error) {
}, },
stateMutex: sync.RWMutex{}, stateMutex: sync.RWMutex{},
stopChan: make(chan struct{}), stopChan: make(chan struct{}),
dirty: make(chan struct{}, 1), dirty: make(chan struct{}, 1),
credentialSubscribers: make(map[string]chan CredentialPrompt),
credSubMutex: sync.RWMutex{},
} }
broker := NewSubscriptionBroker(m.broadcastCredentialPrompt) broker := NewSubscriptionBroker(m.broadcastCredentialPrompt)
@@ -275,37 +273,30 @@ func (m *Manager) Subscribe(id string) chan NetworkState {
func (m *Manager) Unsubscribe(id string) { func (m *Manager) Unsubscribe(id string) {
if val, ok := m.subscribers.LoadAndDelete(id); ok { if val, ok := m.subscribers.LoadAndDelete(id); ok {
close(val.(chan NetworkState)) close(val)
} }
} }
func (m *Manager) SubscribeCredentials(id string) chan CredentialPrompt { func (m *Manager) SubscribeCredentials(id string) chan CredentialPrompt {
ch := make(chan CredentialPrompt, 16) ch := make(chan CredentialPrompt, 16)
m.credSubMutex.Lock() m.credentialSubscribers.Store(id, ch)
m.credentialSubscribers[id] = ch
m.credSubMutex.Unlock()
return ch return ch
} }
func (m *Manager) UnsubscribeCredentials(id string) { func (m *Manager) UnsubscribeCredentials(id string) {
m.credSubMutex.Lock() if ch, ok := m.credentialSubscribers.LoadAndDelete(id); ok {
if ch, ok := m.credentialSubscribers[id]; ok {
close(ch) close(ch)
delete(m.credentialSubscribers, id)
} }
m.credSubMutex.Unlock()
} }
func (m *Manager) broadcastCredentialPrompt(prompt CredentialPrompt) { func (m *Manager) broadcastCredentialPrompt(prompt CredentialPrompt) {
m.credSubMutex.RLock() m.credentialSubscribers.Range(func(key string, ch chan CredentialPrompt) bool {
defer m.credSubMutex.RUnlock()
for _, ch := range m.credentialSubscribers {
select { select {
case ch <- prompt: case ch <- prompt:
default: default:
} }
} return true
})
} }
func (m *Manager) notifier() { func (m *Manager) notifier() {
@@ -337,8 +328,7 @@ func (m *Manager) notifier() {
continue continue
} }
m.subscribers.Range(func(key, value interface{}) bool { m.subscribers.Range(func(key string, ch chan NetworkState) bool {
ch := value.(chan NetworkState)
select { select {
case ch <- currentState: case ch <- currentState:
default: default:
@@ -384,8 +374,7 @@ func (m *Manager) Close() {
m.backend.Close() m.backend.Close()
} }
m.subscribers.Range(func(key, value interface{}) bool { m.subscribers.Range(func(key string, ch chan NetworkState) bool {
ch := value.(chan NetworkState)
close(ch) close(ch)
m.subscribers.Delete(key) m.subscribers.Delete(key)
return true return true

View File

@@ -114,7 +114,7 @@ func TestManager_Close(t *testing.T) {
assert.False(t, ok2, "ch2 should be closed") assert.False(t, ok2, "ch2 should be closed")
count := 0 count := 0
manager.subscribers.Range(func(key, value interface{}) bool { count++; return true }) manager.subscribers.Range(func(key string, ch chan NetworkState) bool { count++; return true })
assert.Equal(t, 0, count) assert.Equal(t, 0, count)
} }

View File

@@ -3,37 +3,29 @@ package network
import ( import (
"context" "context"
"fmt" "fmt"
"sync"
"github.com/AvengeMedia/DankMaterialShell/core/internal/errdefs" "github.com/AvengeMedia/DankMaterialShell/core/internal/errdefs"
"github.com/AvengeMedia/DankMaterialShell/core/internal/log" "github.com/AvengeMedia/DankMaterialShell/core/internal/log"
"github.com/AvengeMedia/DankMaterialShell/core/pkg/syncmap"
) )
type SubscriptionBroker struct { type SubscriptionBroker struct {
mu sync.RWMutex pending syncmap.Map[string, chan PromptReply]
pending map[string]chan PromptReply requests syncmap.Map[string, PromptRequest]
requests map[string]PromptRequest pathSettingToToken syncmap.Map[string, string]
pathSettingToToken map[string]string
broadcastPrompt func(CredentialPrompt) broadcastPrompt func(CredentialPrompt)
} }
func NewSubscriptionBroker(broadcastPrompt func(CredentialPrompt)) PromptBroker { func NewSubscriptionBroker(broadcastPrompt func(CredentialPrompt)) PromptBroker {
return &SubscriptionBroker{ return &SubscriptionBroker{
pending: make(map[string]chan PromptReply), broadcastPrompt: broadcastPrompt,
requests: make(map[string]PromptRequest),
pathSettingToToken: make(map[string]string),
broadcastPrompt: broadcastPrompt,
} }
} }
func (b *SubscriptionBroker) Ask(ctx context.Context, req PromptRequest) (string, error) { func (b *SubscriptionBroker) Ask(ctx context.Context, req PromptRequest) (string, error) {
pathSettingKey := fmt.Sprintf("%s:%s", req.ConnectionPath, req.SettingName) pathSettingKey := fmt.Sprintf("%s:%s", req.ConnectionPath, req.SettingName)
b.mu.Lock() if existingToken, alreadyPending := b.pathSettingToToken.Load(pathSettingKey); alreadyPending {
existingToken, alreadyPending := b.pathSettingToToken[pathSettingKey]
b.mu.Unlock()
if alreadyPending {
log.Infof("[SubscriptionBroker] Duplicate prompt for %s, returning existing token", pathSettingKey) log.Infof("[SubscriptionBroker] Duplicate prompt for %s, returning existing token", pathSettingKey)
return existingToken, nil return existingToken, nil
} }
@@ -44,11 +36,9 @@ func (b *SubscriptionBroker) Ask(ctx context.Context, req PromptRequest) (string
} }
replyChan := make(chan PromptReply, 1) replyChan := make(chan PromptReply, 1)
b.mu.Lock() b.pending.Store(token, replyChan)
b.pending[token] = replyChan b.requests.Store(token, req)
b.requests[token] = req b.pathSettingToToken.Store(pathSettingKey, token)
b.pathSettingToToken[pathSettingKey] = token
b.mu.Unlock()
if b.broadcastPrompt != nil { if b.broadcastPrompt != nil {
prompt := CredentialPrompt{ prompt := CredentialPrompt{
@@ -71,10 +61,7 @@ func (b *SubscriptionBroker) Ask(ctx context.Context, req PromptRequest) (string
} }
func (b *SubscriptionBroker) Wait(ctx context.Context, token string) (PromptReply, error) { func (b *SubscriptionBroker) Wait(ctx context.Context, token string) (PromptReply, error) {
b.mu.RLock() replyChan, exists := b.pending.Load(token)
replyChan, exists := b.pending[token]
b.mu.RUnlock()
if !exists { if !exists {
return PromptReply{}, fmt.Errorf("unknown token: %s", token) return PromptReply{}, fmt.Errorf("unknown token: %s", token)
} }
@@ -93,10 +80,7 @@ func (b *SubscriptionBroker) Wait(ctx context.Context, token string) (PromptRepl
} }
func (b *SubscriptionBroker) Resolve(token string, reply PromptReply) error { func (b *SubscriptionBroker) Resolve(token string, reply PromptReply) error {
b.mu.RLock() replyChan, exists := b.pending.Load(token)
replyChan, exists := b.pending[token]
b.mu.RUnlock()
if !exists { if !exists {
log.Warnf("[SubscriptionBroker] Resolve: unknown or expired token: %s", token) log.Warnf("[SubscriptionBroker] Resolve: unknown or expired token: %s", token)
return fmt.Errorf("unknown or expired token: %s", token) return fmt.Errorf("unknown or expired token: %s", token)
@@ -112,25 +96,19 @@ func (b *SubscriptionBroker) Resolve(token string, reply PromptReply) error {
} }
func (b *SubscriptionBroker) cleanup(token string) { func (b *SubscriptionBroker) cleanup(token string) {
b.mu.Lock() if req, exists := b.requests.Load(token); exists {
defer b.mu.Unlock()
if req, exists := b.requests[token]; exists {
pathSettingKey := fmt.Sprintf("%s:%s", req.ConnectionPath, req.SettingName) pathSettingKey := fmt.Sprintf("%s:%s", req.ConnectionPath, req.SettingName)
delete(b.pathSettingToToken, pathSettingKey) b.pathSettingToToken.Delete(pathSettingKey)
} }
delete(b.pending, token) b.pending.Delete(token)
delete(b.requests, token) b.requests.Delete(token)
} }
func (b *SubscriptionBroker) Cancel(path string, setting string) error { func (b *SubscriptionBroker) Cancel(path string, setting string) error {
pathSettingKey := fmt.Sprintf("%s:%s", path, setting) pathSettingKey := fmt.Sprintf("%s:%s", path, setting)
b.mu.Lock() token, exists := b.pathSettingToToken.Load(pathSettingKey)
token, exists := b.pathSettingToToken[pathSettingKey]
b.mu.Unlock()
if !exists { if !exists {
log.Infof("[SubscriptionBroker] Cancel: no pending prompt for %s", pathSettingKey) log.Infof("[SubscriptionBroker] Cancel: no pending prompt for %s", pathSettingKey)
return nil return nil

View File

@@ -3,6 +3,7 @@ package network
import ( import (
"sync" "sync"
"github.com/AvengeMedia/DankMaterialShell/core/pkg/syncmap"
"github.com/godbus/dbus/v5" "github.com/godbus/dbus/v5"
) )
@@ -108,13 +109,12 @@ type Manager struct {
backend Backend backend Backend
state *NetworkState state *NetworkState
stateMutex sync.RWMutex stateMutex sync.RWMutex
subscribers sync.Map subscribers syncmap.Map[string, chan NetworkState]
stopChan chan struct{} stopChan chan struct{}
dirty chan struct{} dirty chan struct{}
notifierWg sync.WaitGroup notifierWg sync.WaitGroup
lastNotifiedState *NetworkState lastNotifiedState *NetworkState
credentialSubscribers map[string]chan CredentialPrompt credentialSubscribers syncmap.Map[string, chan CredentialPrompt]
credSubMutex sync.RWMutex
} }
type EventType string type EventType string

View File

@@ -28,6 +28,7 @@ import (
"github.com/AvengeMedia/DankMaterialShell/core/internal/server/wayland" "github.com/AvengeMedia/DankMaterialShell/core/internal/server/wayland"
"github.com/AvengeMedia/DankMaterialShell/core/internal/server/wlcontext" "github.com/AvengeMedia/DankMaterialShell/core/internal/server/wlcontext"
"github.com/AvengeMedia/DankMaterialShell/core/internal/server/wlroutput" "github.com/AvengeMedia/DankMaterialShell/core/internal/server/wlroutput"
"github.com/AvengeMedia/DankMaterialShell/core/pkg/syncmap"
) )
const APIVersion = 18 const APIVersion = 18
@@ -59,8 +60,8 @@ var wlrOutputManager *wlroutput.Manager
var evdevManager *evdev.Manager var evdevManager *evdev.Manager
var wlContext *wlcontext.SharedContext var wlContext *wlcontext.SharedContext
var capabilitySubscribers sync.Map var capabilitySubscribers syncmap.Map[string, chan ServerInfo]
var cupsSubscribers sync.Map var cupsSubscribers syncmap.Map[string, bool]
var cupsSubscriberCount atomic.Int32 var cupsSubscriberCount atomic.Int32
func getSocketDir() string { func getSocketDir() string {
@@ -434,8 +435,7 @@ func getServerInfo() ServerInfo {
func notifyCapabilityChange() { func notifyCapabilityChange() {
info := getServerInfo() info := getServerInfo()
capabilitySubscribers.Range(func(key, value interface{}) bool { capabilitySubscribers.Range(func(key string, ch chan ServerInfo) bool {
ch := value.(chan ServerInfo)
select { select {
case ch <- info: case ch <- info:
default: default:

View File

@@ -26,7 +26,6 @@ func NewManager(display *wlclient.Display, config Config) (*Manager, error) {
config: config, config: config,
display: display, display: display,
ctx: display.Context(), ctx: display.Context(),
outputs: make(map[uint32]*outputState),
cmdq: make(chan cmd, 128), cmdq: make(chan cmd, 128),
stopChan: make(chan struct{}), stopChan: make(chan struct{}),
updateTrigger: make(chan struct{}, 1), updateTrigger: make(chan struct{}, 1),
@@ -114,17 +113,17 @@ func (m *Manager) waylandActor() {
} }
func (m *Manager) allOutputsReady() bool { func (m *Manager) allOutputsReady() bool {
m.outputsMutex.RLock() hasOutputs := false
defer m.outputsMutex.RUnlock() allReady := true
if len(m.outputs) == 0 { m.outputs.Range(func(key uint32, value *outputState) bool {
return false hasOutputs = true
} if value.rampSize == 0 || value.failed {
for _, o := range m.outputs { allReady = false
if o.rampSize == 0 || o.failed {
return false return false
} }
} return true
return true })
return hasOutputs && allReady
} }
func (m *Manager) setupDBusMonitor() error { func (m *Manager) setupDBusMonitor() error {
@@ -157,7 +156,6 @@ func (m *Manager) setupRegistry() error {
m.registry = registry m.registry = registry
outputs := make([]*wlclient.Output, 0) outputs := make([]*wlclient.Output, 0)
outputRegNames := make(map[uint32]uint32)
outputNames := make(map[uint32]string) outputNames := make(map[uint32]string)
var gammaMgr *wlr_gamma_control.ZwlrGammaControlManagerV1 var gammaMgr *wlr_gamma_control.ZwlrGammaControlManagerV1
@@ -198,14 +196,9 @@ func (m *Manager) setupRegistry() error {
if gammaMgr != nil { if gammaMgr != nil {
outputs = append(outputs, output) outputs = append(outputs, output)
outputRegNames[outputID] = e.Name
} }
m.outputsMutex.Lock() m.outputRegNames.Store(outputID, e.Name)
if m.outputRegNames != nil {
m.outputRegNames[outputID] = e.Name
}
m.outputsMutex.Unlock()
m.configMutex.RLock() m.configMutex.RLock()
enabled := m.config.Enabled enabled := m.config.Enabled
@@ -236,23 +229,33 @@ func (m *Manager) setupRegistry() error {
registry.SetGlobalRemoveHandler(func(e wlclient.RegistryGlobalRemoveEvent) { registry.SetGlobalRemoveHandler(func(e wlclient.RegistryGlobalRemoveEvent) {
m.post(func() { m.post(func() {
m.outputsMutex.Lock() var foundID uint32
defer m.outputsMutex.Unlock() var foundOut *outputState
m.outputs.Range(func(id uint32, out *outputState) bool {
for id, out := range m.outputs {
if out.registryName == e.Name { if out.registryName == e.Name {
log.Infof("Output %d (registry name %d) removed, destroying gamma control", id, e.Name) foundID = id
if out.gammaControl != nil { foundOut = out
control := out.gammaControl.(*wlr_gamma_control.ZwlrGammaControlV1) return false
control.Destroy() }
} return true
delete(m.outputs, id) })
if len(m.outputs) == 0 { if foundOut != nil {
m.controlsInitialized = false log.Infof("Output %d (registry name %d) removed, destroying gamma control", foundID, e.Name)
log.Info("All outputs removed, controls no longer initialized") if foundOut.gammaControl != nil {
} control := foundOut.gammaControl.(*wlr_gamma_control.ZwlrGammaControlV1)
return control.Destroy()
}
m.outputs.Delete(foundID)
hasOutputs := false
m.outputs.Range(func(key uint32, value *outputState) bool {
hasOutputs = true
return false
})
if !hasOutputs {
m.controlsInitialized = false
log.Info("All outputs removed, controls no longer initialized")
} }
} }
}) })
@@ -292,7 +295,6 @@ func (m *Manager) setupRegistry() error {
m.gammaControl = gammaMgr m.gammaControl = gammaMgr
m.availableOutputs = physicalOutputs m.availableOutputs = physicalOutputs
m.outputRegNames = outputRegNames
log.Info("setupRegistry: completed successfully (gamma controls will be initialized when enabled)") log.Info("setupRegistry: completed successfully (gamma controls will be initialized when enabled)")
return nil return nil
@@ -308,9 +310,12 @@ func (m *Manager) setupOutputControls(outputs []*wlclient.Output, manager *wlr_g
continue continue
} }
outputID := output.ID()
registryName, _ := m.outputRegNames.Load(outputID)
outState := &outputState{ outState := &outputState{
id: output.ID(), id: outputID,
registryName: m.outputRegNames[output.ID()], registryName: registryName,
output: output, output: output,
gammaControl: control, gammaControl: control,
isVirtual: false, isVirtual: false,
@@ -318,14 +323,12 @@ func (m *Manager) setupOutputControls(outputs []*wlclient.Output, manager *wlr_g
func(state *outputState) { func(state *outputState) {
control.SetGammaSizeHandler(func(e wlr_gamma_control.ZwlrGammaControlV1GammaSizeEvent) { control.SetGammaSizeHandler(func(e wlr_gamma_control.ZwlrGammaControlV1GammaSizeEvent) {
m.outputsMutex.Lock() if outState, exists := m.outputs.Load(state.id); exists {
if outState, exists := m.outputs[state.id]; exists {
outState.rampSize = e.Size outState.rampSize = e.Size
outState.failed = false outState.failed = false
outState.retryCount = 0 outState.retryCount = 0
log.Infof("Output %d gamma_size=%d", state.id, e.Size) log.Infof("Output %d gamma_size=%d", state.id, e.Size)
} }
m.outputsMutex.Unlock()
m.transitionMutex.RLock() m.transitionMutex.RLock()
currentTemp := m.currentTemp currentTemp := m.currentTemp
@@ -337,8 +340,7 @@ func (m *Manager) setupOutputControls(outputs []*wlclient.Output, manager *wlr_g
}) })
control.SetFailedHandler(func(e wlr_gamma_control.ZwlrGammaControlV1FailedEvent) { control.SetFailedHandler(func(e wlr_gamma_control.ZwlrGammaControlV1FailedEvent) {
m.outputsMutex.Lock() if outState, exists := m.outputs.Load(state.id); exists {
if outState, exists := m.outputs[state.id]; exists {
outState.failed = true outState.failed = true
outState.rampSize = 0 outState.rampSize = 0
outState.retryCount++ outState.retryCount++
@@ -357,13 +359,10 @@ func (m *Manager) setupOutputControls(outputs []*wlclient.Output, manager *wlr_g
}) })
}) })
} }
m.outputsMutex.Unlock()
}) })
}(outState) }(outState)
m.outputsMutex.Lock() m.outputs.Store(outputID, outState)
m.outputs[output.ID()] = outState
m.outputsMutex.Unlock()
} }
return nil return nil
@@ -375,8 +374,7 @@ func (m *Manager) addOutputControl(output *wlclient.Output) error {
var outputName string var outputName string
output.SetNameHandler(func(ev wlclient.OutputNameEvent) { output.SetNameHandler(func(ev wlclient.OutputNameEvent) {
outputName = ev.Name outputName = ev.Name
m.outputsMutex.Lock() if outState, exists := m.outputs.Load(outputID); exists {
if outState, exists := m.outputs[outputID]; exists {
outState.name = ev.Name outState.name = ev.Name
if len(ev.Name) >= 9 && ev.Name[:9] == "HEADLESS-" { if len(ev.Name) >= 9 && ev.Name[:9] == "HEADLESS-" {
log.Infof("Detected virtual output %d (name=%s), marking for gamma control skip", outputID, ev.Name) log.Infof("Detected virtual output %d (name=%s), marking for gamma control skip", outputID, ev.Name)
@@ -384,7 +382,6 @@ func (m *Manager) addOutputControl(output *wlclient.Output) error {
outState.failed = true outState.failed = true
} }
} }
m.outputsMutex.Unlock()
}) })
gammaMgr := m.gammaControl.(*wlr_gamma_control.ZwlrGammaControlManagerV1) gammaMgr := m.gammaControl.(*wlr_gamma_control.ZwlrGammaControlManagerV1)
@@ -394,24 +391,24 @@ func (m *Manager) addOutputControl(output *wlclient.Output) error {
return fmt.Errorf("failed to get gamma control: %w", err) return fmt.Errorf("failed to get gamma control: %w", err)
} }
registryName, _ := m.outputRegNames.Load(outputID)
outState := &outputState{ outState := &outputState{
id: outputID, id: outputID,
name: outputName, name: outputName,
registryName: m.outputRegNames[outputID], registryName: registryName,
output: output, output: output,
gammaControl: control, gammaControl: control,
isVirtual: false, isVirtual: false,
} }
control.SetGammaSizeHandler(func(e wlr_gamma_control.ZwlrGammaControlV1GammaSizeEvent) { control.SetGammaSizeHandler(func(e wlr_gamma_control.ZwlrGammaControlV1GammaSizeEvent) {
m.outputsMutex.Lock() if out, exists := m.outputs.Load(outState.id); exists {
if out, exists := m.outputs[outState.id]; exists {
out.rampSize = e.Size out.rampSize = e.Size
out.failed = false out.failed = false
out.retryCount = 0 out.retryCount = 0
log.Infof("Output %d gamma_size=%d", outState.id, e.Size) log.Infof("Output %d gamma_size=%d", outState.id, e.Size)
} }
m.outputsMutex.Unlock()
m.transitionMutex.RLock() m.transitionMutex.RLock()
currentTemp := m.currentTemp currentTemp := m.currentTemp
@@ -423,8 +420,7 @@ func (m *Manager) addOutputControl(output *wlclient.Output) error {
}) })
control.SetFailedHandler(func(e wlr_gamma_control.ZwlrGammaControlV1FailedEvent) { control.SetFailedHandler(func(e wlr_gamma_control.ZwlrGammaControlV1FailedEvent) {
m.outputsMutex.Lock() if out, exists := m.outputs.Load(outState.id); exists {
if out, exists := m.outputs[outState.id]; exists {
out.failed = true out.failed = true
out.rampSize = 0 out.rampSize = 0
out.retryCount++ out.retryCount++
@@ -443,12 +439,9 @@ func (m *Manager) addOutputControl(output *wlclient.Output) error {
}) })
}) })
} }
m.outputsMutex.Unlock()
}) })
m.outputsMutex.Lock() m.outputs.Store(outputID, outState)
m.outputs[output.ID()] = outState
m.outputsMutex.Unlock()
log.Infof("Added gamma control for output %d", output.ID()) log.Infof("Added gamma control for output %d", output.ID())
return nil return nil
@@ -623,17 +616,19 @@ func (m *Manager) transitionWorker() {
if !enabled && targetTemp == identityTemp && m.controlsInitialized { if !enabled && targetTemp == identityTemp && m.controlsInitialized {
m.post(func() { m.post(func() {
log.Info("Destroying gamma controls after transition to identity") log.Info("Destroying gamma controls after transition to identity")
m.outputsMutex.Lock() m.outputs.Range(func(id uint32, out *outputState) bool {
for id, out := range m.outputs {
if out.gammaControl != nil { if out.gammaControl != nil {
control := out.gammaControl.(*wlr_gamma_control.ZwlrGammaControlV1) control := out.gammaControl.(*wlr_gamma_control.ZwlrGammaControlV1)
control.Destroy() control.Destroy()
log.Debugf("Destroyed gamma control for output %d", id) log.Debugf("Destroyed gamma control for output %d", id)
} }
} return true
m.outputs = make(map[uint32]*outputState) })
m.outputs.Range(func(key uint32, value *outputState) bool {
m.outputs.Delete(key)
return true
})
m.controlsInitialized = false m.controlsInitialized = false
m.outputsMutex.Unlock()
m.transitionMutex.Lock() m.transitionMutex.Lock()
m.currentTemp = identityTemp m.currentTemp = identityTemp
@@ -661,9 +656,7 @@ func (m *Manager) recreateOutputControl(out *outputState) error {
return nil return nil
} }
m.outputsMutex.RLock() _, exists := m.outputs.Load(out.id)
_, exists := m.outputs[out.id]
m.outputsMutex.RUnlock()
if !exists { if !exists {
return nil return nil
@@ -689,14 +682,12 @@ func (m *Manager) recreateOutputControl(out *outputState) error {
state := out state := out
control.SetGammaSizeHandler(func(e wlr_gamma_control.ZwlrGammaControlV1GammaSizeEvent) { control.SetGammaSizeHandler(func(e wlr_gamma_control.ZwlrGammaControlV1GammaSizeEvent) {
m.outputsMutex.Lock() if outState, exists := m.outputs.Load(state.id); exists {
if outState, exists := m.outputs[state.id]; exists {
outState.rampSize = e.Size outState.rampSize = e.Size
outState.failed = false outState.failed = false
outState.retryCount = 0 outState.retryCount = 0
log.Infof("Output %d gamma_size=%d (recreated)", state.id, e.Size) log.Infof("Output %d gamma_size=%d (recreated)", state.id, e.Size)
} }
m.outputsMutex.Unlock()
m.transitionMutex.RLock() m.transitionMutex.RLock()
currentTemp := m.currentTemp currentTemp := m.currentTemp
@@ -708,8 +699,7 @@ func (m *Manager) recreateOutputControl(out *outputState) error {
}) })
control.SetFailedHandler(func(e wlr_gamma_control.ZwlrGammaControlV1FailedEvent) { control.SetFailedHandler(func(e wlr_gamma_control.ZwlrGammaControlV1FailedEvent) {
m.outputsMutex.Lock() if outState, exists := m.outputs.Load(state.id); exists {
if outState, exists := m.outputs[state.id]; exists {
outState.failed = true outState.failed = true
outState.rampSize = 0 outState.rampSize = 0
outState.retryCount++ outState.retryCount++
@@ -728,7 +718,6 @@ func (m *Manager) recreateOutputControl(out *outputState) error {
}) })
}) })
} }
m.outputsMutex.Unlock()
}) })
out.gammaControl = control out.gammaControl = control
@@ -750,13 +739,11 @@ func (m *Manager) applyNowOnActor(temp int) {
return return
} }
// Lock while snapshotting outputs to prevent races with recreateOutputControl
m.outputsMutex.RLock()
var outs []*outputState var outs []*outputState
for _, out := range m.outputs { m.outputs.Range(func(key uint32, value *outputState) bool {
outs = append(outs, out) outs = append(outs, value)
} return true
m.outputsMutex.RUnlock() })
if len(outs) == 0 { if len(outs) == 0 {
return return
@@ -796,20 +783,17 @@ func (m *Manager) applyNowOnActor(temp int) {
if err := m.setGammaBytesActor(j.out, j.data); err != nil { if err := m.setGammaBytesActor(j.out, j.data); err != nil {
log.Warnf("Failed to set gamma for output %d: %v", j.out.id, err) log.Warnf("Failed to set gamma for output %d: %v", j.out.id, err)
outID := j.out.id outID := j.out.id
m.outputsMutex.Lock() if out, exists := m.outputs.Load(outID); exists {
if out, exists := m.outputs[outID]; exists {
out.failed = true out.failed = true
out.rampSize = 0 out.rampSize = 0
} }
m.outputsMutex.Unlock()
time.AfterFunc(300*time.Millisecond, func() { time.AfterFunc(300*time.Millisecond, func() {
m.post(func() { m.post(func() {
m.outputsMutex.RLock() if out, exists := m.outputs.Load(outID); exists {
out, exists := m.outputs[outID] if out.failed {
m.outputsMutex.RUnlock() m.recreateOutputControl(out)
if exists && out.failed { }
m.recreateOutputControl(out)
} }
}) })
}) })
@@ -943,8 +927,7 @@ func (m *Manager) notifier() {
continue continue
} }
m.subscribers.Range(func(key, value interface{}) bool { m.subscribers.Range(func(key string, ch chan State) bool {
ch := value.(chan State)
select { select {
case ch <- currentState: case ch <- currentState:
default: default:
@@ -1290,17 +1273,19 @@ func (m *Manager) SetEnabled(enabled bool) {
if currentTemp == identityTemp { if currentTemp == identityTemp {
m.post(func() { m.post(func() {
log.Infof("Already at %dK, destroying gamma controls immediately", identityTemp) log.Infof("Already at %dK, destroying gamma controls immediately", identityTemp)
m.outputsMutex.Lock() m.outputs.Range(func(id uint32, out *outputState) bool {
for id, out := range m.outputs {
if out.gammaControl != nil { if out.gammaControl != nil {
control := out.gammaControl.(*wlr_gamma_control.ZwlrGammaControlV1) control := out.gammaControl.(*wlr_gamma_control.ZwlrGammaControlV1)
control.Destroy() control.Destroy()
log.Debugf("Destroyed gamma control for output %d", id) log.Debugf("Destroyed gamma control for output %d", id)
} }
} return true
m.outputs = make(map[uint32]*outputState) })
m.outputs.Range(func(key uint32, value *outputState) bool {
m.outputs.Delete(key)
return true
})
m.controlsInitialized = false m.controlsInitialized = false
m.outputsMutex.Unlock()
m.transitionMutex.Lock() m.transitionMutex.Lock()
m.currentTemp = identityTemp m.currentTemp = identityTemp
@@ -1326,21 +1311,22 @@ func (m *Manager) Close() {
m.wg.Wait() m.wg.Wait()
m.notifierWg.Wait() m.notifierWg.Wait()
m.subscribers.Range(func(key, value interface{}) bool { m.subscribers.Range(func(key string, ch chan State) bool {
ch := value.(chan State)
close(ch) close(ch)
m.subscribers.Delete(key) m.subscribers.Delete(key)
return true return true
}) })
m.outputsMutex.Lock() m.outputs.Range(func(key uint32, out *outputState) bool {
for _, out := range m.outputs {
if control, ok := out.gammaControl.(*wlr_gamma_control.ZwlrGammaControlV1); ok { if control, ok := out.gammaControl.(*wlr_gamma_control.ZwlrGammaControlV1); ok {
control.Destroy() control.Destroy()
} }
} return true
m.outputs = make(map[uint32]*outputState) })
m.outputsMutex.Unlock() m.outputs.Range(func(key uint32, value *outputState) bool {
m.outputs.Delete(key)
return true
})
if manager, ok := m.gammaControl.(*wlr_gamma_control.ZwlrGammaControlManagerV1); ok { if manager, ok := m.gammaControl.(*wlr_gamma_control.ZwlrGammaControlManagerV1); ok {
manager.Destroy() manager.Destroy()

View File

@@ -7,6 +7,7 @@ import (
"github.com/AvengeMedia/DankMaterialShell/core/internal/errdefs" "github.com/AvengeMedia/DankMaterialShell/core/internal/errdefs"
wlclient "github.com/AvengeMedia/DankMaterialShell/core/pkg/go-wayland/wayland/client" wlclient "github.com/AvengeMedia/DankMaterialShell/core/pkg/go-wayland/wayland/client"
"github.com/AvengeMedia/DankMaterialShell/core/pkg/syncmap"
"github.com/godbus/dbus/v5" "github.com/godbus/dbus/v5"
) )
@@ -48,9 +49,8 @@ type Manager struct {
registry *wlclient.Registry registry *wlclient.Registry
gammaControl interface{} gammaControl interface{}
availableOutputs []*wlclient.Output availableOutputs []*wlclient.Output
outputRegNames map[uint32]uint32 outputRegNames syncmap.Map[uint32, uint32]
outputs map[uint32]*outputState outputs syncmap.Map[uint32, *outputState]
outputsMutex sync.RWMutex
controlsInitialized bool controlsInitialized bool
cmdq chan cmd cmdq chan cmd
@@ -69,7 +69,7 @@ type Manager struct {
cachedIPLon *float64 cachedIPLon *float64
locationMutex sync.RWMutex locationMutex sync.RWMutex
subscribers sync.Map subscribers syncmap.Map[string, chan State]
dirty chan struct{} dirty chan struct{}
notifierWg sync.WaitGroup notifierWg sync.WaitGroup
lastNotified *State lastNotified *State
@@ -152,7 +152,7 @@ func (m *Manager) Subscribe(id string) chan State {
func (m *Manager) Unsubscribe(id string) { func (m *Manager) Unsubscribe(id string) {
if val, ok := m.subscribers.LoadAndDelete(id); ok { if val, ok := m.subscribers.LoadAndDelete(id); ok {
close(val.(chan State)) close(val)
} }
} }

View File

@@ -155,8 +155,7 @@ func (m *Manager) ApplyConfiguration(heads []HeadConfig, test bool) error {
}) })
headsByName := make(map[string]*headState) headsByName := make(map[string]*headState)
m.heads.Range(func(key, value interface{}) bool { m.heads.Range(func(key uint32, head *headState) bool {
head := value.(*headState)
if !head.finished { if !head.finished {
headsByName[head.name] = head headsByName[head.name] = head
} }
@@ -188,14 +187,13 @@ func (m *Manager) ApplyConfiguration(heads []HeadConfig, test bool) error {
} }
if headCfg.ModeID != nil { if headCfg.ModeID != nil {
val, exists := m.modes.Load(*headCfg.ModeID) mode, exists := m.modes.Load(*headCfg.ModeID)
if !exists { if !exists {
config.Destroy() config.Destroy()
resultChan <- fmt.Errorf("mode not found: %d", *headCfg.ModeID) resultChan <- fmt.Errorf("mode not found: %d", *headCfg.ModeID)
return return
} }
mode := val.(*modeState)
if err := headConfig.SetMode(mode.handle); err != nil { if err := headConfig.SetMode(mode.handle); err != nil {
config.Destroy() config.Destroy()

View File

@@ -274,8 +274,7 @@ func (m *Manager) handleMode(headID uint32, e wlr_output_management.ZwlrOutputHe
m.modes.Store(modeID, mode) m.modes.Store(modeID, mode)
if val, ok := m.heads.Load(headID); ok { if head, ok := m.heads.Load(headID); ok {
head := val.(*headState)
head.modeIDs = append(head.modeIDs, modeID) head.modeIDs = append(head.modeIDs, modeID)
m.heads.Store(headID, head) m.heads.Store(headID, head)
} }
@@ -324,8 +323,7 @@ func (m *Manager) handleMode(headID uint32, e wlr_output_management.ZwlrOutputHe
func (m *Manager) updateState() { func (m *Manager) updateState() {
outputs := make([]Output, 0) outputs := make([]Output, 0)
m.heads.Range(func(key, value interface{}) bool { m.heads.Range(func(key uint32, head *headState) bool {
head := value.(*headState)
if head.finished { if head.finished {
return true return true
} }
@@ -334,11 +332,10 @@ func (m *Manager) updateState() {
var currentMode *OutputMode var currentMode *OutputMode
for _, modeID := range head.modeIDs { for _, modeID := range head.modeIDs {
val, exists := m.modes.Load(modeID) mode, exists := m.modes.Load(modeID)
if !exists { if !exists {
continue continue
} }
mode := val.(*modeState)
if mode.finished { if mode.finished {
continue continue
} }
@@ -439,8 +436,7 @@ func (m *Manager) notifier() {
continue continue
} }
m.subscribers.Range(func(key, value interface{}) bool { m.subscribers.Range(func(key string, ch chan State) bool {
ch := value.(chan State)
select { select {
case ch <- currentState: case ch <- currentState:
default: default:
@@ -461,15 +457,13 @@ func (m *Manager) Close() {
m.wg.Wait() m.wg.Wait()
m.notifierWg.Wait() m.notifierWg.Wait()
m.subscribers.Range(func(key, value interface{}) bool { m.subscribers.Range(func(key string, ch chan State) bool {
ch := value.(chan State)
close(ch) close(ch)
m.subscribers.Delete(key) m.subscribers.Delete(key)
return true return true
}) })
m.modes.Range(func(key, value interface{}) bool { m.modes.Range(func(key uint32, mode *modeState) bool {
mode := value.(*modeState)
if mode.handle != nil { if mode.handle != nil {
mode.handle.Release() mode.handle.Release()
} }
@@ -477,8 +471,7 @@ func (m *Manager) Close() {
return true return true
}) })
m.heads.Range(func(key, value interface{}) bool { m.heads.Range(func(key uint32, head *headState) bool {
head := value.(*headState)
if head.handle != nil { if head.handle != nil {
head.handle.Release() head.handle.Release()
} }

View File

@@ -5,6 +5,7 @@ import (
"github.com/AvengeMedia/DankMaterialShell/core/internal/proto/wlr_output_management" "github.com/AvengeMedia/DankMaterialShell/core/internal/proto/wlr_output_management"
wlclient "github.com/AvengeMedia/DankMaterialShell/core/pkg/go-wayland/wayland/client" wlclient "github.com/AvengeMedia/DankMaterialShell/core/pkg/go-wayland/wayland/client"
"github.com/AvengeMedia/DankMaterialShell/core/pkg/syncmap"
) )
type OutputMode struct { type OutputMode struct {
@@ -49,8 +50,8 @@ type Manager struct {
registry *wlclient.Registry registry *wlclient.Registry
manager *wlr_output_management.ZwlrOutputManagerV1 manager *wlr_output_management.ZwlrOutputManagerV1
heads sync.Map // map[uint32]*headState heads syncmap.Map[uint32, *headState]
modes sync.Map // map[uint32]*modeState modes syncmap.Map[uint32, *modeState]
serial uint32 serial uint32
@@ -59,7 +60,7 @@ type Manager struct {
stopChan chan struct{} stopChan chan struct{}
wg sync.WaitGroup wg sync.WaitGroup
subscribers sync.Map subscribers syncmap.Map[string, chan State]
dirty chan struct{} dirty chan struct{}
notifierWg sync.WaitGroup notifierWg sync.WaitGroup
lastNotified *State lastNotified *State
@@ -125,7 +126,7 @@ func (m *Manager) Subscribe(id string) chan State {
func (m *Manager) Unsubscribe(id string) { func (m *Manager) Unsubscribe(id string) {
if val, ok := m.subscribers.LoadAndDelete(id); ok { if val, ok := m.subscribers.LoadAndDelete(id); ok {
close(val.(chan State)) close(val)
} }

28
core/pkg/syncmap/LICENSE Normal file
View File

@@ -0,0 +1,28 @@
Copyright 2009 The Go Authors.
Copyright 2024 Zachary Olstein.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google LLC nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

537
core/pkg/syncmap/syncmap.go Normal file
View File

@@ -0,0 +1,537 @@
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package syncmap
import (
"sync"
"sync/atomic"
"unsafe"
)
// Map is like a Go map[K]V but is safe for concurrent use
// by multiple goroutines without additional locking or coordination.
// Loads, stores, and deletes run in amortized constant time.
//
// The Map type is specialized. Most code should use a plain Go map instead,
// with separate locking or coordination, for better type safety and to make it
// easier to maintain other invariants along with the map content.
//
// The Map type is optimized for two common use cases: (1) when the entry for a given
// key is only ever written once but read many times, as in caches that only grow,
// or (2) when multiple goroutines read, write, and overwrite entries for disjoint
// sets of keys. In these two cases, use of a Map may significantly reduce lock
// contention compared to a Go map paired with a separate [Mutex] or [RWMutex].
//
// The zero Map is empty and ready for use. A Map must not be copied after first use.
//
// In the terminology of [the Go memory model], Map arranges that a write operation
// “synchronizes before” any read operation that observes the effect of the write, where
// read and write operations are defined as follows.
// [Map.Load], [Map.LoadAndDelete], [Map.LoadOrStore], and [Map.Swap] are read operations;
// [Map.Delete], [Map.LoadAndDelete], [Map.Store], and [Map.Swap] are write operations;
// [Map.LoadOrStore] is a write operation when it returns loaded set to false.
//
// [the Go memory model]: https://go.dev/ref/mem
type Map[K comparable, V any] struct {
mu sync.Mutex
// read contains the portion of the map's contents that are safe for
// concurrent access (with or without mu held).
//
// The read field itself is always safe to load, but must only be stored with
// mu held.
//
// Entries stored in read may be updated concurrently without mu, but updating
// a previously-expunged entry requires that the entry be copied to the dirty
// map and unexpunged with mu held.
read atomic.Pointer[readOnly[K, V]]
// dirty contains the portion of the map's contents that require mu to be
// held. To ensure that the dirty map can be promoted to the read map quickly,
// it also includes all of the non-expunged entries in the read map.
//
// Expunged entries are not stored in the dirty map. An expunged entry in the
// clean map must be unexpunged and added to the dirty map before a new value
// can be stored to it.
//
// If the dirty map is nil, the next write to the map will initialize it by
// making a shallow copy of the clean map, omitting stale entries.
dirty map[K]*entry[V]
// misses counts the number of loads since the read map was last updated that
// needed to lock mu to determine whether the key was present.
//
// Once enough misses have occurred to cover the cost of copying the dirty
// map, the dirty map will be promoted to the read map (in the unamended
// state) and the next store to the map will make a new dirty copy.
misses int
}
// readOnly is an immutable struct stored atomically in the Map.read field.
type readOnly[K comparable, V any] struct {
m map[K]*entry[V]
amended bool // true if the dirty map contains some key not in m.
}
// expunged is an arbitrary pointer that marks entries which have been deleted
// from the dirty map.
// Because the same expunged pointer is used regardless of the Map's value type,
// value pointers read from the map must be compared against expunged BEFORE
// casting the pointer to *V.
var expunged = unsafe.Pointer(new(int))
// An entry is a slot in the map corresponding to a particular key.
type entry[V any] struct {
// p points to the value stored for the entry.
//
// If p == nil, the entry has been deleted, and either m.dirty == nil or
// m.dirty[key] is e.
//
// If p == expunged, the entry has been deleted, m.dirty != nil, and the entry
// is missing from m.dirty.
//
// Otherwise, the entry is valid and recorded in m.read.m[key] and, if m.dirty
// != nil, in m.dirty[key].
//
// If p != expunged, it is always safe to cast it to (*V).
//
// An entry can be deleted by atomic replacement with nil: when m.dirty is
// next created, it will atomically replace nil with expunged and leave
// m.dirty[key] unset.
//
// An entry's associated value can be updated by atomic replacement, provided
// p != expunged. If p == expunged, an entry's associated value can be updated
// only after first setting m.dirty[key] = e so that lookups using the dirty
// map find the entry.
p unsafe.Pointer
}
func newEntry[V any](i V) *entry[V] {
e := &entry[V]{}
atomic.StorePointer(&e.p, unsafe.Pointer(&i))
return e
}
func (m *Map[K, V]) loadReadOnly() readOnly[K, V] {
if p := m.read.Load(); p != nil {
return *p
}
return readOnly[K, V]{}
}
// Load returns the value stored in the map for a key, or nil if no
// value is present.
// The ok result indicates whether value was found in the map.
func (m *Map[K, V]) Load(key K) (value V, ok bool) {
read := m.loadReadOnly()
e, ok := read.m[key]
if !ok && read.amended {
m.mu.Lock()
// Avoid reporting a spurious miss if m.dirty got promoted while we were
// blocked on m.mu. (If further loads of the same key will not miss, it's
// not worth copying the dirty map for this key.)
read = m.loadReadOnly()
e, ok = read.m[key]
if !ok && read.amended {
e, ok = m.dirty[key]
// Regardless of whether the entry was present, record a miss: this key
// will take the slow path until the dirty map is promoted to the read
// map.
m.missLocked()
}
m.mu.Unlock()
}
if !ok {
return value, false
}
return e.load()
}
func (e *entry[V]) load() (value V, ok bool) {
p := atomic.LoadPointer(&e.p)
if p == nil || p == expunged {
return value, false
}
return *(*V)(p), true
}
// Store sets the value for a key.
func (m *Map[K, V]) Store(key K, value V) {
_, _ = m.Swap(key, value)
}
// unexpungeLocked ensures that the entry is not marked as expunged.
//
// If the entry was previously expunged, it must be added to the dirty map
// before m.mu is unlocked.
func (e *entry[V]) unexpungeLocked() (wasExpunged bool) {
return atomic.CompareAndSwapPointer(&e.p, expunged, nil)
}
// swapLocked unconditionally swaps a value into the entry.
//
// The entry must be known not to be expunged.
func (e *entry[V]) swapLocked(i *V) *V {
return (*V)(atomic.SwapPointer(&e.p, unsafe.Pointer(i)))
}
// LoadOrStore returns the existing value for the key if present.
// Otherwise, it stores and returns the given value.
// The loaded result is true if the value was loaded, false if stored.
func (m *Map[K, V]) LoadOrStore(key K, value V) (actual V, loaded bool) {
// Avoid locking if it's a clean hit.
read := m.loadReadOnly()
if e, ok := read.m[key]; ok {
actual, loaded, ok := e.tryLoadOrStore(value)
if ok {
return actual, loaded
}
}
m.mu.Lock()
read = m.loadReadOnly()
if e, ok := read.m[key]; ok {
if e.unexpungeLocked() {
m.dirty[key] = e
}
actual, loaded, _ = e.tryLoadOrStore(value)
} else if e, ok := m.dirty[key]; ok {
actual, loaded, _ = e.tryLoadOrStore(value)
m.missLocked()
} else {
if !read.amended {
// We're adding the first new key to the dirty map.
// Make sure it is allocated and mark the read-only map as incomplete.
m.dirtyLocked()
m.read.Store(&readOnly[K, V]{m: read.m, amended: true})
}
m.dirty[key] = newEntry(value)
actual, loaded = value, false
}
m.mu.Unlock()
return actual, loaded
}
// tryLoadOrStore atomically loads or stores a value if the entry is not
// expunged.
//
// If the entry is expunged, tryLoadOrStore leaves the entry unchanged and
// returns with ok==false.
func (e *entry[V]) tryLoadOrStore(i V) (actual V, loaded, ok bool) {
ptr := atomic.LoadPointer(&e.p)
if ptr == expunged {
return actual, false, false
}
p := (*V)(ptr)
if p != nil {
return *p, true, true
}
// Copy the interface after the first load to make this method more amenable
// to escape analysis: if we hit the "load" path or the entry is expunged, we
// shouldn't bother heap-allocating.
ic := i
for {
if atomic.CompareAndSwapPointer(&e.p, nil, unsafe.Pointer(&ic)) {
return i, false, true
}
ptr = atomic.LoadPointer(&e.p)
if ptr == expunged {
return actual, false, false
}
p = (*V)(ptr)
if p != nil {
return *p, true, true
}
}
}
// LoadAndDelete deletes the value for a key, returning the previous value if any.
// The loaded result reports whether the key was present.
func (m *Map[K, V]) LoadAndDelete(key K) (value V, loaded bool) {
read := m.loadReadOnly()
e, ok := read.m[key]
if !ok && read.amended {
m.mu.Lock()
read = m.loadReadOnly()
e, ok = read.m[key]
if !ok && read.amended {
e, ok = m.dirty[key]
delete(m.dirty, key)
// Regardless of whether the entry was present, record a miss: this key
// will take the slow path until the dirty map is promoted to the read
// map.
m.missLocked()
}
m.mu.Unlock()
}
if ok {
return e.delete()
}
return value, false
}
// Delete deletes the value for a key.
func (m *Map[K, V]) Delete(key K) {
m.LoadAndDelete(key)
}
func (e *entry[V]) delete() (value V, ok bool) {
for {
p := atomic.LoadPointer(&e.p)
if p == nil || p == expunged {
return value, false
}
if atomic.CompareAndSwapPointer(&e.p, p, nil) {
return *(*V)(p), true
}
}
}
// trySwap swaps a value if the entry has not been expunged.
//
// If the entry is expunged, trySwap returns false and leaves the entry
// unchanged.
func (e *entry[V]) trySwap(i *V) (*V, bool) {
for {
p := atomic.LoadPointer(&e.p)
if p == expunged {
return nil, false
}
if atomic.CompareAndSwapPointer(&e.p, p, unsafe.Pointer(i)) {
return (*V)(p), true
}
}
}
// Swap swaps the value for a key and returns the previous value if any.
// The loaded result reports whether the key was present.
func (m *Map[K, V]) Swap(key K, value V) (previous V, loaded bool) {
read := m.loadReadOnly()
if e, ok := read.m[key]; ok {
if v, ok := e.trySwap(&value); ok {
if v == nil {
return previous, false
}
return *v, true
}
}
m.mu.Lock()
read = m.loadReadOnly()
if e, ok := read.m[key]; ok {
if e.unexpungeLocked() {
// The entry was previously expunged, which implies that there is a
// non-nil dirty map and this entry is not in it.
m.dirty[key] = e
}
if v := e.swapLocked(&value); v != nil {
loaded = true
previous = *v
}
} else if e, ok := m.dirty[key]; ok {
if v := e.swapLocked(&value); v != nil {
loaded = true
previous = *v
}
} else {
if !read.amended {
// We're adding the first new key to the dirty map.
// Make sure it is allocated and mark the read-only map as incomplete.
m.dirtyLocked()
m.read.Store(&readOnly[K, V]{m: read.m, amended: true})
}
m.dirty[key] = newEntry(value)
}
m.mu.Unlock()
return previous, loaded
}
// Range calls f sequentially for each key and value present in the map.
// If f returns false, range stops the iteration.
//
// Range does not necessarily correspond to any consistent snapshot of the Map's
// contents: no key will be visited more than once, but if the value for any key
// is stored or deleted concurrently (including by f), Range may reflect any
// mapping for that key from any point during the Range call. Range does not
// block other methods on the receiver; even f itself may call any method on m.
//
// Range may be O(N) with the number of elements in the map even if f returns
// false after a constant number of calls.
func (m *Map[K, V]) Range(f func(key K, value V) bool) {
// We need to be able to iterate over all of the keys that were already
// present at the start of the call to Range.
// If read.amended is false, then read.m satisfies that property without
// requiring us to hold m.mu for a long time.
read := m.loadReadOnly()
if read.amended {
// m.dirty contains keys not in read.m. Fortunately, Range is already O(N)
// (assuming the caller does not break out early), so a call to Range
// amortizes an entire copy of the map: we can promote the dirty copy
// immediately!
m.mu.Lock()
read = m.loadReadOnly()
if read.amended {
read = readOnly[K, V]{m: m.dirty}
copyRead := read
m.read.Store(&copyRead)
m.dirty = nil
m.misses = 0
}
m.mu.Unlock()
}
for k, e := range read.m {
v, ok := e.load()
if !ok {
continue
}
if !f(k, v) {
break
}
}
}
// CompareAndSwap swaps the old and new values for key
// if the value stored in the map is equal to old.
// The old value must be of a comparable type.
func CompareAndSwap[K comparable, V comparable](m *Map[K, V], key K, old, new V) (swapped bool) {
read := m.loadReadOnly()
if e, ok := read.m[key]; ok {
return tryCompareAndSwap(e, old, new)
} else if !read.amended {
return false // No existing value for key.
}
m.mu.Lock()
defer m.mu.Unlock()
read = m.loadReadOnly()
swapped = false
if e, ok := read.m[key]; ok {
swapped = tryCompareAndSwap(e, old, new)
} else if e, ok := m.dirty[key]; ok {
swapped = tryCompareAndSwap(e, old, new)
// We needed to lock mu in order to load the entry for key,
// and the operation didn't change the set of keys in the map
// (so it would be made more efficient by promoting the dirty
// map to read-only).
// Count it as a miss so that we will eventually switch to the
// more efficient steady state.
m.missLocked()
}
return swapped
}
// CompareAndDelete deletes the entry for key if its value is equal to old.
// The old value must be of a comparable type.
//
// If there is no current value for key in the map, CompareAndDelete
// returns false (even if the old value is the zero value of V).
func CompareAndDelete[K comparable, V comparable](m *Map[K, V], key K, old V) (deleted bool) {
read := m.loadReadOnly()
e, ok := read.m[key]
if !ok && read.amended {
m.mu.Lock()
read = m.loadReadOnly()
e, ok = read.m[key]
if !ok && read.amended {
e, ok = m.dirty[key]
// Don't delete key from m.dirty: we still need to do the “compare” part
// of the operation. The entry will eventually be expunged when the
// dirty map is promoted to the read map.
//
// Regardless of whether the entry was present, record a miss: this key
// will take the slow path until the dirty map is promoted to the read
// map.
m.missLocked()
}
m.mu.Unlock()
}
for ok {
ptr := atomic.LoadPointer(&e.p)
if ptr == nil || ptr == expunged {
return false
}
p := (*V)(ptr)
if *p != old {
return false
}
if atomic.CompareAndSwapPointer(&e.p, ptr, nil) {
return true
}
}
return false
}
// tryCompareAndSwap compare the entry with the given old value and swaps
// it with a new value if the entry is equal to the old value, and the entry
// has not been expunged.
//
// If the entry is expunged, tryCompareAndSwap returns false and leaves
// the entry unchanged.
func tryCompareAndSwap[V comparable](e *entry[V], old, new V) bool {
ptr := atomic.LoadPointer(&e.p)
if ptr == nil || ptr == expunged {
return false
}
p := (*V)(ptr)
if *p != old {
return false
}
// Copy the interface after the first load to make this method more amenable
// to escape analysis: if the comparison fails from the start, we shouldn't
// bother heap-allocating an interface value to store.
nc := new
for {
if atomic.CompareAndSwapPointer(&e.p, ptr, unsafe.Pointer(&nc)) {
return true
}
ptr = atomic.LoadPointer(&e.p)
if ptr == nil || ptr == expunged {
return false
}
p = (*V)(ptr)
if *p != old {
return false
}
}
}
func (m *Map[K, V]) missLocked() {
m.misses++
if m.misses < len(m.dirty) {
return
}
m.read.Store(&readOnly[K, V]{m: m.dirty})
m.dirty = nil
m.misses = 0
}
func (m *Map[K, V]) dirtyLocked() {
if m.dirty != nil {
return
}
read := m.loadReadOnly()
m.dirty = make(map[K]*entry[V], len(read.m))
for k, e := range read.m {
if !e.tryExpungeLocked() {
m.dirty[k] = e
}
}
}
func (e *entry[V]) tryExpungeLocked() (isExpunged bool) {
p := atomic.LoadPointer(&e.p)
for p == nil {
if atomic.CompareAndSwapPointer(&e.p, nil, expunged) {
return true
}
p = atomic.LoadPointer(&e.p)
}
return p == expunged
}