#!/usr/bin/env python3
"""
fanmon  —  Fan & Temperature Monitor + GPU Curve Editor
MSI MAG B550 Tomahawk  /  Radeon RX 7900 XT  /  Ryzen 7 5800X3D
"""

import curses, os, sys, time, glob, re, subprocess

# ── GPU OD paths (RX 7900 XT) ─────────────────────────────────────────────────
GPU_DEV      = '/sys/devices/pci0000:00/0000:00:03.1/0000:2b:00.0/0000:2c:00.0/0000:2d:00.0'
FAN_CTRL     = f'{GPU_DEV}/gpu_od/fan_ctrl'
FAN_CURVE_F  = f'{FAN_CTRL}/fan_curve'             # OD fan curve anchor points
FAN_ZERO_RPM = f'{FAN_CTRL}/fan_zero_rpm_enable'
FAN_MIN_PWM  = f'{FAN_CTRL}/fan_minimum_pwm'
OD_COMMIT    = f'{GPU_DEV}/pp_od_clk_voltage'      # writing 'c' here commits OD edits
FANCURVE_SVC = 'fancurve'                          # systemd unit restarted after patching
FANCURVE_BIN = '/usr/local/bin/fancurve'           # script whose CURVE list gets patched

# Accepted OD value ranges (used for clamping and validation).
TEMP_MIN, TEMP_MAX = 25, 100   # hotspot temperature, °C
DUTY_MIN, DUTY_MAX = 23, 100   # fan duty, %

# ── Helpers ────────────────────────────────────────────────────────────────────
def ri(p):
    """Read an integer from a (sysfs) file.

    Returns None if the file is missing/unreadable or its content is not an
    integer. The file handle is closed deterministically.
    """
    try:
        with open(p) as f:
            return int(f.read().strip())
    except (OSError, ValueError):
        return None

def syswrite(path, value):
    """Write ``str(value)`` plus a newline to a sysfs attribute.

    Tries a direct write first. On PermissionError it retries through
    non-interactive ``sudo -n tee`` (fails instead of prompting).
    Returns True on success and False on any failure; it never raises.
    That includes a missing ``sudo`` binary or a hung helper (10 s timeout).
    """
    data = str(value) + '\n'
    try:
        with open(path, 'w') as f:
            f.write(data)
        return True
    except PermissionError:
        pass                    # fall through to the sudo path below
    except OSError:
        return False            # e.g. EINVAL from the driver, missing file
    try:
        r = subprocess.run(
            ['sudo', '-n', 'tee', path],
            input=data, text=True, capture_output=True, timeout=10
        )
    except (OSError, subprocess.SubprocessError):
        # sudo not installed / not executable, or the helper timed out
        return False
    return r.returncode == 0

def bar(val, lo, hi, width=22):
    """Return a text gauge of `width` cells showing where `val` lies in [lo, hi].

    Filled cells are '█', empty ones '░'. Values outside the range are pinned
    to the ends; a degenerate range (lo == hi) renders as empty.
    """
    if hi == lo:
        fraction = 0.0
    else:
        fraction = max(0.0, min(1.0, (val - lo) / (hi - lo)))
    filled = round(fraction * width)
    return ''.join(('█' * filled, '░' * (width - filled)))

def temp_attr(t):
    """Pick a colour attribute for a temperature in °C.

    Below 60 uses pair 2 (green), below 80 uses pair 3 (yellow), and anything
    else uses pair 4 (red).
    """
    for limit, pair in ((60, 2), (80, 3)):
        if t < limit:
            return curses.color_pair(pair)
    return curses.color_pair(4)

def safe(win, y, x, text, attr=0):
    """Draw `text` at (y, x) on `win`, clipped to the window width.

    Coordinates outside the window are ignored. Every row 0..h-1 is
    drawable, including the bottom row used for status bars. Filling the
    bottom-right cell makes curses report an error *after* the text has been
    drawn (the cursor cannot wrap). That curses.error is swallowed here, so
    the bottom row needs no special exclusion.
    """
    h, w = win.getmaxyx()
    if not (0 <= y < h and 0 <= x < w):
        return
    try:
        win.addstr(y, x, text[:w - x], attr)
    except curses.error:
        pass

# ── hwmon device ───────────────────────────────────────────────────────────────
class Hwmon:
    """One /sys/class/hwmon/hwmonN device plus its most recent readings.

    The reading dictionaries are keyed by the sensor index as a string
    ('1', '2', ...), as parsed from the attribute file names.
    """

    def __init__(self, path):
        self.path = path
        self.name = self._read_text(path + '/name', '?')
        self.temps = {}   # idx → {'label': str, 'val': float °C}
        self.fans  = {}   # idx → {'label': str, 'rpm': int}
        self.pwms  = {}   # idx → {'duty': int %, 'pwm': int raw 0-255, 'mode': int|None, 'rw': bool}

    @staticmethod
    def _read_text(path, default):
        """Return the stripped contents of `path`, or `default` if it cannot be read."""
        try:
            with open(path) as f:
                return f.read().strip()
        except (OSError, ValueError):
            return default

    def _scan(self, prefix, suffix=''):
        """Yield (idx, path, value) for each readable ``<prefix>N<suffix>`` attribute.

        N may have any number of digits. Files that are unreadable or hold a
        non-integer are skipped.
        """
        pattern = re.compile(rf'/{prefix}(\d+){re.escape(suffix)}$')
        for f in sorted(glob.glob(f'{self.path}/{prefix}*{suffix}')):
            m = pattern.search(f)
            if not m:
                continue            # e.g. pwm1_enable when scanning 'pwm'
            v = ri(f)
            if v is not None:
                yield m.group(1), f, v

    def refresh(self):
        """Re-read all temperature, fan and PWM attributes of this device."""
        p = self.path
        self.temps.clear(); self.fans.clear(); self.pwms.clear()

        for i, _, v in self._scan('temp', '_input'):
            label = self._read_text(f'{p}/temp{i}_label', f'Temp{i}')
            self.temps[i] = {'label': label, 'val': v / 1000.0}   # m°C → °C

        for i, _, v in self._scan('fan', '_input'):
            label = self._read_text(f'{p}/fan{i}_label', f'Fan{i}')
            self.fans[i] = {'label': label, 'rpm': v}

        for i, f, v in self._scan('pwm'):
            self.pwms[i] = {'duty': round(v * 100 / 255), 'pwm': v,
                            'mode': ri(f'{p}/pwm{i}_enable'),   # None if absent
                            'rw': os.access(f, os.W_OK)}


def discover():
    """Return a Hwmon for every hwmonN under /sys/class/hwmon that has a 'name' file.

    Devices are ordered numerically by N (hwmon2 before hwmon10).
    """
    def hwmon_index(path):
        return int(re.search(r'\d+$', path).group())

    candidates = sorted(glob.glob('/sys/class/hwmon/hwmon*'), key=hwmon_index)
    return [Hwmon(c) for c in candidates if os.path.exists(c + '/name')]


# ── GPU OD curve ───────────────────────────────────────────────────────────────
def read_curve(path=None):
    """Return the GPU OD fan curve as five ``[temp_°C, duty_%]`` pairs.

    Parses lines of the form ``0: 30C 23%`` from FAN_CURVE_F (or from `path`
    if given). If the file cannot be read or does not yield exactly five
    points, a fresh copy of the built-in default curve is returned, so
    callers may mutate the result freely.
    """
    if path is None:
        path = FAN_CURVE_F
    try:
        with open(path) as f:
            txt = f.read()
    except (OSError, ValueError):
        txt = ''
    pts = [[int(m.group(2)), int(m.group(3))]
           for m in re.finditer(r'(\d+): (\d+)C (\d+)%', txt)]
    if len(pts) == 5:
        return pts
    return [[30, 23], [55, 23], [72, 45], [83, 65], [95, 90]]

def apply_curve(pts):
    """Validate `pts`, write them to the GPU OD fan curve and commit.

    Afterwards the fancurve service script is updated via _patch_service().
    Returns ``(ok, msg)``, where `msg` is a user-facing (German) status line.
    ``ok`` is False only if validation, a curve write, or the commit failed.
    If the curve is active on the GPU but the service could not be updated,
    ``ok`` is True and the message says so.
    """
    for i, (t, d) in enumerate(pts):
        if not (TEMP_MIN <= t <= TEMP_MAX):
            return False, f'Pt{i}: temp {t}°C außerhalb {TEMP_MIN}-{TEMP_MAX}'
        if not (DUTY_MIN <= d <= DUTY_MAX):
            return False, f'Pt{i}: duty {d}% außerhalb {DUTY_MIN}-{DUTY_MAX}'
    for i in range(1, len(pts)):
        if pts[i][0] < pts[i-1][0]:
            return False, f'Temps müssen aufsteigend sein (Punkt {i})'

    # Results deliberately ignored (best-effort).
    # NOTE(review): presumably because not every kernel exposes these
    # attributes — confirm.
    syswrite(FAN_ZERO_RPM, 0)
    syswrite(FAN_MIN_PWM, DUTY_MIN)
    for i, (t, d) in enumerate(pts):
        if not syswrite(FAN_CURVE_F, f'{i} {t} {d}'):
            return False, f'Schreibfehler Punkt {i}'
    if not syswrite(OD_COMMIT, 'c'):
        return False, 'Commit fehlgeschlagen'

    # Only an explicit False is a failure. Any other result (e.g. None)
    # counts as success, so the success message stays unchanged.
    if _patch_service(pts) is False:
        return True, 'Kurve aktiv, aber fancurve-Dienst nicht aktualisiert!'
    return True, 'Kurve gespeichert!'

def _patch_service(pts):
    """Update CURVE constant in fancurve.py and restart service."""
    try:
        src = open(FANCURVE_BIN).read()
        new = 'CURVE = [\n' + ''.join(f'    ({t},  {d}),\n' for t,d in pts) + ']'
        patched = re.sub(r'CURVE\s*=\s*\[.*?\]', new, src, flags=re.DOTALL)
        tmp = '/tmp/_fanmon_patch.py'
        open(tmp, 'w').write(patched)
        subprocess.run(['sudo','-n','install','-m','755',tmp,FANCURVE_BIN],
                       capture_output=True, timeout=5)
        subprocess.run(['sudo','-n','systemctl','restart',FANCURVE_SVC],
                       capture_output=True, timeout=8)
    except Exception:
        pass


# ── Colors ─────────────────────────────────────────────────────────────────────
def init_colors():
    """Enable curses colours and register the colour pairs used by fanmon.

    -1 as background means "terminal default" (enabled by use_default_colors).
    """
    curses.start_color()
    curses.use_default_colors()
    palette = (
        (1, curses.COLOR_CYAN,    -1),                  # heading
        (2, curses.COLOR_GREEN,   -1),                  # cool
        (3, curses.COLOR_YELLOW,  -1),                  # warm
        (4, curses.COLOR_RED,     -1),                  # hot
        (5, curses.COLOR_WHITE,   -1),                  # normal
        (6, curses.COLOR_MAGENTA, -1),                  # selected
        (7, curses.COLOR_BLACK,   curses.COLOR_WHITE),  # inverse bars
    )
    for pair, fg, bg in palette:
        curses.init_pair(pair, fg, bg)


# ── Monitor view ───────────────────────────────────────────────────────────────
class Monitor:
    """Main screen: GPU, CPU, mainboard and NVMe sensors, re-read every 3 s."""

    def __init__(self, scr):
        self.scr = scr
        self.devs = discover()
        self.t_refresh = 0
        self.curve = []          # GPU OD fan curve, re-read together with the sensors
        self.msg = ''; self.msg_t = 0

    def tick(self):
        """Re-read all devices and the GPU OD curve once the data is >= 3 s old."""
        if time.time() - self.t_refresh >= 3:
            for d in self.devs:
                d.refresh()
            # Read here rather than in draw(), which runs every 500 ms.
            self.curve = read_curve()
            self.t_refresh = time.time()

    def draw(self):
        """Render one frame of the monitor view."""
        s = self.scr
        s.erase()
        h, w = s.getmaxyx()
        bw = max(16, min(24, w - 44))

        # ── title bar
        ts = time.strftime('%H:%M:%S')
        hdr = f' fanmon  q:beenden  e:GPU-Kurve bearbeiten  r:aktualisieren  {ts} '
        safe(s, 0, 0, hdr.ljust(w), curses.color_pair(7))
        row = 2

        def heading(txt):
            nonlocal row
            if row < h-1:
                safe(s, row, 1, f'── {txt} ', curses.color_pair(1)|curses.A_BOLD)
                row += 1

        def trow(label, val, lo=0, hi=100, unit='°C'):
            """Temperature row: label, value, coloured gauge."""
            nonlocal row
            if row >= h-2: return
            b = bar(val, lo, hi, bw)
            safe(s, row, 3,  f'{label:<20}', curses.A_NORMAL)
            safe(s, row, 23, f'{val:6.1f}{unit}  ', temp_attr(val)|curses.A_BOLD)
            safe(s, row, 34, b, temp_attr(val))
            row += 1

        def frow(label, rpm, duty=None, extra=''):
            """Fan row: label, RPM, optional duty gauge and trailing note."""
            nonlocal row
            if row >= h-2: return
            safe(s, row, 3,  f'{label:<20}', curses.A_NORMAL)
            rpm_s = f'{rpm:5d} RPM' if rpm >= 0 else ' N/A  RPM'
            safe(s, row, 23, rpm_s, curses.color_pair(2 if rpm > 0 else 3))
            if duty is not None:
                b = bar(duty, 0, 100, 12)
                safe(s, row, 34, f'{duty:3d}%  ', curses.color_pair(2))
                safe(s, row, 40, b, curses.color_pair(2))
            if extra:
                safe(s, row, 54, extra, curses.A_DIM)
            row += 1

        # ── GPU
        gpu = next((d for d in self.devs if d.name == 'amdgpu'), None)
        heading('GPU: amdgpu  RX 7900 XT')
        if gpu:
            for i, t in sorted(gpu.temps.items()):
                trow(t['label'], t['val'], 0, 110)
            for i, f in sorted(gpu.fans.items()):
                duty = gpu.pwms.get(i, {}).get('duty')
                frow(f['label'], f['rpm'], duty, '(OD-Kurve aktiv)')
            if row < h-2:
                cstr = '  '.join(f'{t}°→{d}%' for t, d in self.curve)
                safe(s, row, 3, 'OD-Kurve: ', curses.A_BOLD)
                safe(s, row, 13, cstr, curses.color_pair(1))
                row += 1
        else:
            safe(s, row, 3, '(amdgpu nicht gefunden)', curses.color_pair(4)); row += 1
        row += 1

        # ── CPU
        cpu = next((d for d in self.devs if d.name == 'k10temp'), None)
        heading('CPU: k10temp  Ryzen 7 5800X3D')
        if cpu:
            for i, t in sorted(cpu.temps.items()):
                trow(t['label'], t['val'], 0, 95)
        else:
            safe(s, row, 3, '(k10temp nicht gefunden)', curses.color_pair(4)); row += 1
        row += 1

        # ── Case / Mainboard (nct668x) — read-only
        mb = next((d for d in self.devs if d.name in ('nct6687','nct6686','nct6683')), None)
        chip = mb.name if mb else 'nct6687'   # show the chip that was actually found
        heading(f'Gehäuse / Mainboard: {chip}  (BIOS-gesteuert — nur Anzeige)')
        if mb:
            for i, t in sorted(mb.temps.items()):
                trow(t['label'], t['val'], 0, 90)
            for i, f in sorted(mb.fans.items()):
                if f['rpm'] == 0: continue  # skip silent/absent headers
                duty = mb.pwms.get(i, {}).get('duty')
                frow(f['label'], f['rpm'], duty, '(BIOS)')
            # also show zero-RPM headers, dimmed; bounds-check before drawing
            for i, f in sorted(mb.fans.items()):
                if f['rpm'] > 0: continue
                if row >= h-2: break
                safe(s, row, 3, f'{f["label"]:<20}', curses.A_DIM)
                safe(s, row, 23, '    0 RPM', curses.A_DIM)
                row += 1
        else:
            safe(s, row, 3, '(nct6687 nicht gefunden)', curses.color_pair(3)); row += 1
        row += 1

        # ── NVMe
        nvmes = [d for d in self.devs if 'nvme' in d.name]
        if nvmes:
            heading('NVMe-Drives')
            for nv in nvmes:
                for i, t in sorted(nv.temps.items()):
                    trow(f'{nv.name} {t["label"]}', t['val'], 0, 80)
            row += 1

        # ── status bar (last row; transient message shown for 5 s)
        age = int(time.time() - self.t_refresh)
        st = f' Aktualisiert vor {age}s '
        if self.msg and (time.time() - self.msg_t) < 5:
            st += f'│ {self.msg} '
        safe(s, h-1, 0, st.ljust(w), curses.color_pair(7))
        s.refresh()

    def run(self):
        """Main loop: refresh, redraw, dispatch keys; returns on 'q'."""
        try:
            curses.curs_set(0)
        except curses.error:
            pass   # terminal cannot change cursor visibility
        self.scr.nodelay(True)
        self.scr.timeout(500)   # getch waits at most 500 ms → clock keeps ticking
        init_colors()
        while True:
            self.tick()
            self.draw()
            k = self.scr.getch()
            if k in (ord('q'), ord('Q')):
                break
            elif k in (ord('r'), ord('R')):
                self.t_refresh = 0
            elif k in (ord('e'), ord('E')):
                # The editor needs blocking key reads.
                self.scr.nodelay(False); self.scr.timeout(-1)
                msg = CurveEditor(self.scr).run()
                if msg:
                    self.msg = msg; self.msg_t = time.time()
                self.scr.nodelay(True); self.scr.timeout(500)
                self.t_refresh = 0   # force re-read so the new curve shows at once


# ── GPU Curve Editor ───────────────────────────────────────────────────────────
class CurveEditor:
    """Interactive editor for the GPU OD fan curve (hotspot °C → fan duty %).

    Keys:
    - ↑↓ / j k select a point.
    - ← → / h l / Shift-Tab select a field; Tab toggles the field.
    - + / - change the value by 1; Shift+↑↓ by 5.
    - Digits start inline entry.
    - Enter applies; r restores the defaults; q / Esc aborts.
    """

    DEFAULT_PTS = ((30, 23), (55, 23), (72, 45), (83, 65), (95, 90))   # restored by 'r'
    N_POINTS = 5   # anchor points the curve is padded to

    def __init__(self, scr):
        self.scr = scr
        self.pts = read_curve()
        while len(self.pts) < self.N_POINTS:
            self.pts.append([85, 80])
        self.row = 0   # selected point index
        self.col = 0   # selected field: 0 = temp, 1 = duty
        self.msg = ''  # error from the last failed apply, shown until the next key

    @staticmethod
    def _cursor(visibility):
        """Set cursor visibility, ignoring terminals that cannot change it."""
        try:
            curses.curs_set(visibility)
        except curses.error:
            pass

    def draw(self):
        """Render the editor screen (colour pairs are initialised once in run())."""
        s = self.scr
        s.erase()
        h, w = s.getmaxyx()

        safe(s, 0, 0, ' GPU Fan-Kurve bearbeiten '.center(w), curses.color_pair(7))
        safe(s, 1, 2,
             f'5 Punkte: Hotspot-Temperatur → Lüfter-Duty%   '
             f'Bereich: {TEMP_MIN}-{TEMP_MAX}°C / {DUTY_MIN}-{DUTY_MAX}%',
             curses.A_DIM)

        # column headers
        safe(s, 3, 3,  'Pt', curses.A_BOLD)
        safe(s, 3, 7,  'Temp (°C)', curses.A_BOLD)
        safe(s, 3, 20, 'Duty (%)', curses.A_BOLD)
        safe(s, 3, 32, 'Visualisierung (0-100%)', curses.A_BOLD)
        safe(s, 4, 2, '─' * min(72, w-3), curses.A_DIM)

        bw = max(16, min(32, w - 38))
        for i, (temp, duty) in enumerate(self.pts):
            y = 5 + i
            sel = (i == self.row)
            arrow = '▶' if sel else ' '
            base = curses.color_pair(6)|curses.A_BOLD if sel else curses.A_NORMAL

            safe(s, y, 1, arrow, base)
            safe(s, y, 3, str(i), base)

            # The active cell of the selected row is drawn inverted.
            t_at = curses.color_pair(7) if (sel and self.col == 0) else base
            d_at = curses.color_pair(7) if (sel and self.col == 1) else base
            safe(s, y, 7,  f'{temp:4d}°C', t_at)
            safe(s, y, 20, f'{duty:3d}%',  d_at)
            safe(s, y, 32, bar(duty, 0, 100, bw), temp_attr(temp))

        sep_y = 5 + len(self.pts) + 1
        safe(s, sep_y, 2, '─' * min(72, w-3), curses.A_DIM)

        help_txt = [
            '↑↓  Punkt wählen   Tab/←→  Feld wechseln (Temp ↔ Duty)',
            '+/-  ±1    Shift+↑↓  ±5    Zifferntasten: Wert direkt eingeben',
            'Enter  Übernehmen & speichern   r  Standard zurücksetzen   q  Abbrechen',
        ]
        help_y = sep_y + 2
        for i, line in enumerate(help_txt):
            safe(s, help_y + i, 2, line, curses.A_DIM)

        if self.msg:
            safe(s, help_y + len(help_txt) + 1, 2, self.msg,
                 curses.color_pair(4)|curses.A_BOLD)

        s.refresh()

    def clamp(self, field, v):
        """Clamp `v` into the OD range of `field` (0 = temp, 1 = duty)."""
        return max(TEMP_MIN, min(TEMP_MAX, v)) if field == 0 \
               else max(DUTY_MIN, min(DUTY_MAX, v))

    def adjust(self, delta):
        """Add `delta` to the active cell, clamped to its OD range."""
        r, c = self.row, self.col
        self.pts[r][c] = self.clamp(c, self.pts[r][c] + delta)

    def inline_number(self, seed=''):
        """Let the user type a value for the active cell, pre-filled with `seed`.

        Enter accepts, Esc cancels and Backspace deletes. At most 4 digits
        are allowed. The accepted value is clamped into the field's OD range.
        """
        s = self.scr
        self._cursor(1)
        buf = seed
        y = 5 + self.row
        x = 7 if self.col == 0 else 20
        while True:
            self.draw()
            safe(s, y, x, f'{buf:<6}', curses.color_pair(7))
            s.refresh()
            k = s.getch()
            if k in (10, 13, curses.KEY_ENTER):
                break
            elif k in (curses.KEY_BACKSPACE, 127, 8):
                buf = buf[:-1]
            elif ord('0') <= k <= ord('9') and len(buf) < 4:
                buf += chr(k)
            elif k == 27:
                buf = ''
                break
        self._cursor(0)
        if buf:
            try:
                v = int(buf)
            except ValueError:
                return              # only reachable with a non-numeric seed
            self.pts[self.row][self.col] = self.clamp(self.col, v)

    def run(self):
        """Run the editor until the curve is applied or the user aborts.

        Returns apply_curve()'s message after a successful apply, or None on
        q / Esc. A failed apply keeps the editor open with the error shown,
        so the user's edits are not lost.
        """
        self._cursor(0)
        init_colors()
        while True:
            self.draw()
            k = self.scr.getch()
            self.msg = ''
            n = len(self.pts)
            if k in (ord('q'), ord('Q'), 27):
                return None
            elif k in (curses.KEY_UP, ord('k')):
                self.row = (self.row - 1) % n
            elif k in (curses.KEY_DOWN, ord('j')):
                self.row = (self.row + 1) % n
            elif k in (curses.KEY_LEFT, ord('h'), curses.KEY_BTAB):
                self.col = 0
            elif k in (curses.KEY_RIGHT, ord('l')):
                self.col = 1
            elif k == ord('\t'):
                self.col ^= 1          # Tab toggles Temp ↔ Duty, as the help says
            elif k in (ord('+'), ord('=')):
                self.adjust(1)
            elif k == ord('-'):
                self.adjust(-1)
            elif k == curses.KEY_SR:   # Shift+Up
                self.adjust(5)
            elif k == curses.KEY_SF:   # Shift+Down
                self.adjust(-5)
            elif ord('0') <= k <= ord('9'):
                self.inline_number(chr(k))   # the typed digit seeds the entry
            elif k in (10, 13, curses.KEY_ENTER):
                ok, msg = apply_curve(self.pts)
                if ok:
                    return msg
                self.msg = msg
            elif k in (ord('r'), ord('R')):
                self.pts = [list(p) for p in self.DEFAULT_PTS]


# ── Entry ──────────────────────────────────────────────────────────────────────
def main(stdscr):
    """Entry point handed to curses.wrapper: run the monitor until the user quits."""
    app = Monitor(stdscr)
    app.run()


if __name__ == '__main__':
    try:
        curses.wrapper(main)
    except KeyboardInterrupt:
        # Ctrl+C is a normal way out. curses.wrapper restores the terminal
        # before the exception propagates here, so nothing else to clean up.
        pass
