#!/usr/bin/env python3
# SPDX-License-Identifier: GPL-3.0-or-later
# Copyright (C) 2018 Takashi Sakamoto

import sys
import time
import signal

from hinawa_utils.misc.cli_kit import CliKit
from hinawa_utils.efw.efw_unit import EfwUnit


def handle_hardware_info(unit, args):
    """Print the hardware description held in ``unit.info``.

    Args:
        unit: Object exposing an ``info`` mapping with the keys read below
            ('model', 'features', 'clock-sources', 'sampling-rates', ...).
        args: Remaining command-line arguments; not used by this command.

    Returns:
        Always True.
    """
    info = unit.info

    def print_enabled(title, table):
        # Each table maps a name to a flag; list only the truthy entries.
        print(title)
        for key, enabled in table.items():
            if enabled:
                print('    ', key)

    def print_ports(title, ports):
        print(title)
        for index, port in enumerate(ports):
            print('        {0:02d}: {1}'.format(index, port))

    def print_stream_channels(title, channels, rates):
        # 'channels' is indexed in the same order as 'rates'.
        print(title)
        for index, rate in enumerate(rates):
            print('{0:>16}: {1}'.format(rate, channels[index]))

    print('Model: {0}'.format(info['model']))
    print_enabled('Supported features:', info['features'])
    print_enabled('Supported signals for clock source:', info['clock-sources'])
    print_enabled('Supported sampling rates:', info['sampling-rates'])

    print('List of physical ports:')
    print_ports('    Inputs:', info['phys-inputs'])
    print_ports('    Outputs:', info['phys-outputs'])

    print('The number of channels for mixer:')
    print('    Capture: ', info['capture-channels'])
    print('    Playback:', info['playback-channels'])

    print('The number of PCM channels in a stream:')
    rates = ['44.1/48.0', '88.2/96.0']
    # The third rate group is listed only when 176.4 kHz is flagged as
    # supported; a missing key is treated as unsupported.
    if info['sampling-rates'].get(176400):
        rates.append('176.4/192.0')
    print_stream_channels('    tx:', info['tx-stream-channels'], rates)
    print_stream_channels('    rx:', info['rx-stream-channels'], rates)

    print('Firmware versions:')
    for key, value in info['firmware-versions'].items():
        print('{0:>8}: {1}'.format(key, value))

    return True


def handle_clock_source(unit, args):
    """Get or set the signal source of the sampling clock.

    Args:
        unit: Object providing ``info['clock-sources']``,
            ``get_clock_state()`` and ``set_clock_state(rate, source)``.
        args: ``['get']`` or ``['set', SOURCE]``.

    Returns:
        True when the request was carried out, False after printing usage.
    """
    ops = ('set', 'get')
    sources = [name
               for name, supported in unit.info['clock-sources'].items()
               if supported]
    if len(args) >= 1 and args[0] in ops:
        op = args[0]
        # Rate and source are written together, so read the current pair and
        # replace only the source.
        rate, source = unit.get_clock_state()
        if op == 'get':
            print(source)
            return True
        if len(args) >= 2 and args[1] in sources:
            unit.set_clock_state(rate, args[1])
            return True
    print('Arguments for clock-source command:')
    print('  clock-source OP [SOURCE]')
    print('    OP:        [{0}]'.format('|'.join(ops)))
    print('    SOURCE:    [{0}]'.format('|'.join(sources)))
    return False


def handle_sampling_rate(unit, args):
    """Get or set the sampling rate of the unit.

    Args:
        unit: Object providing ``info['sampling-rates']``,
            ``get_clock_state()`` and ``set_clock_state(rate, source)``.
        args: ``['get']`` or ``['set', RATE]`` where RATE is an integer in Hz
            taken from the supported rates.

    Returns:
        True when the request was carried out, False after printing usage
        (including when RATE is not a number or not supported).
    """
    ops = ('set', 'get')
    rates = [rate
             for rate, supported in unit.info['sampling-rates'].items()
             if supported]
    if len(args) >= 1 and args[0] in ops:
        op = args[0]
        # Rate and source are written together, so read the current pair and
        # replace only the rate.
        rate, source = unit.get_clock_state()
        if op == 'get':
            print(rate)
            return True
        if len(args) >= 2:
            try:
                requested = int(args[1])
            except ValueError:
                # Malformed number: fall through to the usage text.
                requested = None
            if requested in rates:
                unit.set_clock_state(requested, source)
                return True
    print('Arguments for sampling-rate command:')
    print('  sampling-rate OP [RATE]')
    print('    OP:        [{0}]'.format('|'.join(ops)))
    print('    RATE:      [{0}]'.format('|'.join(map(str, rates))))
    return False


def handle_control_room(unit, args):
    """Get or set the source mirrored to the control room output.

    Args:
        unit: Object providing ``get_control_room_source_labels()``,
            ``get_control_room_mirroring()`` and
            ``set_control_room_mirroring(source)``.
        args: ``['get']`` or ``['set', SOURCE]``.

    Returns:
        True when the request was carried out, False after printing usage.
    """
    operations = ('set', 'get')
    labels = unit.get_control_room_source_labels()
    action = args[0] if len(args) >= 1 else None

    if action == 'set' and len(args) >= 2 and args[1] in labels:
        unit.set_control_room_mirroring(args[1])
        return True
    if action == 'get':
        print(unit.get_control_room_mirroring())
        return True

    # Anything else is a malformed request: show the accepted arguments.
    usage = (
        'Arguments for control-room command:',
        '  control-room OP [SOURCE]',
        '    OP:     [{0}]'.format('|'.join(operations)),
        '    SOURCE: [{0}]'.format('|'.join(labels)),
    )
    for line in usage:
        print(line)
    return False


def handle_digital_input(unit, args):
    """Get or set the mode of the digital input.

    Args:
        unit: Object providing ``get_digital_input_mode_labels()``,
            ``get_digital_input_mode()`` and ``set_digital_input_mode(mode)``.
        args: ``['get']`` or ``['set', MODE]``.

    Returns:
        True when the request was carried out, False after printing usage.
    """
    operations = ('set', 'get')
    labels = unit.get_digital_input_mode_labels()
    action = args[0] if len(args) >= 1 else None

    if action == 'set' and len(args) >= 2 and args[1] in labels:
        mode = args[1]
        unit.set_digital_input_mode(mode)
        return True
    if action == 'get':
        print(unit.get_digital_input_mode())
        return True

    # Anything else is a malformed request: show the accepted arguments.
    usage = (
        'Arguments for digital-input command:',
        '  digital-input OP [MODE]',
        '    OP:     [{0}]'.format('|'.join(operations)),
        '    MODE:   [{0}] if OP=set'.format('|'.join(labels)),
    )
    for line in usage:
        print(line)
    return False


def handle_phantom_powering(unit, args):
    """Get or set the phantom powering switch.

    Args:
        unit: Object providing ``get_phantom_powering()`` and
            ``set_phantom_powering(mode)``.
        args: ``['get']`` or ``['set', ENABLE]`` where ENABLE is 0 or 1.

    Returns:
        True when the request was carried out, False after printing usage
        (including a 'set' whose value is missing, malformed or not 0/1).
    """
    ops = ('set', 'get')
    if len(args) >= 1 and args[0] in ops:
        op = args[0]
        if op == 'get':
            print(unit.get_phantom_powering())
            return True
        # op == 'set': a failed 'set' must not silently degrade into 'get'.
        if len(args) >= 2:
            try:
                mode = int(args[1])
            except ValueError:
                mode = None
            if mode in (0, 1):
                unit.set_phantom_powering(mode)
                return True
    print('Arguments for phantom-powering command:')
    print('  phantom-powering OP [ENABLE]')
    print('    OP:     [{0}]'.format('|'.join(ops)))
    print('    ENABLE: [0|1]')
    return False


def handle_box_state(unit, args):
    """Get or set one of the named box states of the unit.

    Args:
        unit: Object providing ``get_box_states()``,
            ``get_box_state_labels(name)`` and
            ``set_box_states(name, label)``.
        args: ``[NAME, 'get']`` or ``[NAME, 'set', VALUE]``.

    Returns:
        True when the request was carried out, False after printing usage.
    """
    operations = ('set', 'get')
    current = unit.get_box_states()

    if len(args) >= 2:
        name = args[0]
        action = args[1]
        if name in current and action in operations:
            choices = unit.get_box_state_labels(name)
            if action == 'get':
                print(current[name])
                return True
            # Only 'set' remains; it needs a value taken from the labels.
            if len(args) >= 3 and args[2] in choices:
                unit.set_box_states(name, args[2])
                return True

    print('Arguments for box-state command:')
    print('  box-state NAME OP [VALUE]')
    print('    NAME:   [{0}]'.format('|'.join(current)))
    print('    OP:     [{0}]'.format('|'.join(operations)))
    # One VALUE line per state, since every state has its own label set.
    for key in current:
        choices = '|'.join(unit.get_box_state_labels(key))
        print('    VALUE:  [{0}] if NAME={1}'.format(choices, key))
    return False


def handle_rx_stream_mapping(unit, args):
    """Get or set which stream pair feeds a stereo pair of physical outputs.

    Args:
        unit: Object providing ``info['phys-outputs']``,
            ``info['rx-stream-channels']``, ``get_stream_mapping()`` and
            ``set_stream_mapping(rx_map, tx_map)``.
        args: ``[TARGET, 'get']`` or ``[TARGET, 'set', SOURCE]``.

    Returns:
        True when the request was carried out, False after printing usage.
    """
    ops = ('set', 'get')
    # Generate stereo pair for outputs. Pairs restart from 1/2 whenever the
    # port name changes; names are compared by value, not identity.
    targets = []
    prev = None
    num = 0
    for index, name in enumerate(unit.info['phys-outputs']):
        if index % 2 == 0:
            if name != prev:
                num = 0
            targets.append('{0}-{1}/{2}'.format(name,
                                                num * 2 + 1, num * 2 + 2))
            prev = name
            num += 1
    # Generate stereo pair for sources
    sources = ['stream-{0}/{1}'.format(ch + 1, ch + 2)
               for ch in range(0, unit.info['rx-stream-channels'][0], 2)]
    maps = unit.get_stream_mapping()
    if len(args) >= 2:
        target, op = args[0:2]
        if target in targets and op in ops:
            target = targets.index(target)
            if op == 'get':
                # Map entries are channel offsets; halve to get the pair.
                print(sources[maps['rx-map'][target] // 2])
                return True
            if len(args) >= 3 and args[2] in sources:
                # Store the channel offset of the pair (pair index * 2) so
                # that 'get' above reads back what was set.
                maps['rx-map'][target] = sources.index(args[2]) * 2
                unit.set_stream_mapping(maps['rx-map'], maps['tx-map'])
                return True
    print('Arguments for rx-stream-mapping command:')
    print('  rx-stream-mapping TARGET OP [SOURCE]')
    print('    TARGET: [{0}]'.format('|'.join(targets)))
    print('    OP:     [{0}]'.format('|'.join(ops)))
    print('    SOURCE: [{0}]'.format('|'.join(sources)))
    return False


def handle_tx_stream_mapping(unit, args):
    """Get or set which stream pair a stereo pair of physical inputs feeds.

    Args:
        unit: Object providing ``info['phys-inputs']``,
            ``info['tx-stream-channels']``, ``get_stream_mapping()`` and
            ``set_stream_mapping(rx_map, tx_map)``.
        args: ``['get', INPUT]`` or ``['set', INPUT, SINK]``.

    Returns:
        True when the request was carried out, False after printing usage
        (including a 'set' whose SINK is missing or unknown).
    """
    ops = ('set', 'get')
    # Generate stereo pair for inputs. Pairs restart from 1/2 whenever the
    # port name changes; names are compared by value, not identity.
    inputs = []
    prev = None
    num = 0
    for index, name in enumerate(unit.info['phys-inputs']):
        if index % 2 == 0:
            if name != prev:
                num = 0
            inputs.append('{0}-{1}/{2}'.format(name, num * 2 + 1, num * 2 + 2))
            prev = name
            num += 1
    # Generate stereo pair for sinks
    sinks = ['stream-{0}/{1}'.format(ch + 1, ch + 2)
             for ch in range(0, unit.info['tx-stream-channels'][0], 2)]
    maps = unit.get_stream_mapping()
    if len(args) >= 2:
        op, port = args[0:2]
        if op in ops and port in inputs:
            port = inputs.index(port)
            if op == 'get':
                # Map entries are channel offsets; halve to get the pair.
                print(sinks[maps['tx-map'][port] // 2])
                return True
            if len(args) >= 3 and args[2] in sinks:
                maps['tx-map'][port] = sinks.index(args[2]) * 2
                # Pass the rx map through untouched; only the tx map changed.
                unit.set_stream_mapping(maps['rx-map'], maps['tx-map'])
                return True
    print('Arguments for tx-stream-mapping command:')
    print('  tx-stream-mapping OP INPUT [SINK]')
    print('    OP:     [{0}]'.format('|'.join(ops)))
    print('    INPUT:  [{0}]'.format('|'.join(inputs)))
    print('    SINK:   [{0}]'.format('|'.join(sinks)))
    return False


def handle_output(unit, args):
    """Get or set a parameter of a physical output channel.

    Args:
        unit: Object providing ``info['features']``, ``info['phys-outputs']``
            and the ``set_phys_out_*`` / ``get_phys_out_*`` methods.
        args: ``[ITEM, CH, 'get']`` or ``[ITEM, CH, 'set', VALUE]``. VALUE is
            parsed as float for 'gain' and as int for the other items.

    Returns:
        True when the request was carried out, False after printing usage
        (including when CH or VALUE is malformed or CH is out of range).
    """
    items = {
        'gain':     (unit.set_phys_out_gain, unit.get_phys_out_gain),
        'mute':     (unit.set_phys_out_mute, unit.get_phys_out_mute),
    }
    ops = ('set', 'get')
    # 'nominal' is offered only when the unit reports the feature.
    if unit.info['features']['nominal-output']:
        items['nominal'] = (
            unit.set_phys_out_nominal,
            unit.get_phys_out_nominal,
        )
    count = len(unit.info['phys-outputs'])
    if len(args) >= 3:
        item, ch, op = args[0:3]
        try:
            ch = int(ch)
        except ValueError:
            # Not a number: force the range check below to fail.
            ch = -1
        if item in items and 0 <= ch < count and op in ops:
            set_func, get_func = items[item]
            if op == 'get':
                print(get_func(ch))
                return True
            if len(args) >= 4:
                try:
                    val = float(args[3]) if item == 'gain' else int(args[3])
                except ValueError:
                    val = None
                if val is not None:
                    set_func(ch, val)
                    return True
    print('Arguments for output command:')
    print('  output ITEM CH OP [dB|ENABLE]')
    print('    ITEM:   [{0}]'.format('|'.join(items)))
    print('    CH:     [0-{0}]'.format(count - 1))
    print('    OP:     [{0}]'.format('|'.join(ops)))
    print('    dB:     [-144.0..0.0] if OP=set and ITEM=gain')
    print('    ENABLE: [0|1]         if OP=set and ITEM=mute')
    if 'nominal' in items:
        print('    ENABLE: [0|1]         if OP=set and ITEM=nominal')
    return False


def handle_input(unit, args):
    """Get or set a parameter of a physical input channel.

    Args:
        unit: Object providing ``info['phys-inputs']``,
            ``set_phys_in_nominal(ch, val)`` and ``get_phys_in_nominal(ch)``.
        args: ``[ITEM, CH, 'get']`` or ``[ITEM, CH, 'set', ENABLE]``.

    Returns:
        True when the request was carried out, False after printing usage
        (including when CH or ENABLE is malformed or CH is out of range).
    """
    items = {
        'nominal':  (unit.set_phys_in_nominal, unit.get_phys_in_nominal),
    }
    ops = ('set', 'get')
    count = len(unit.info['phys-inputs'])
    if len(args) >= 3:
        item, ch, op = args[0:3]
        try:
            ch = int(ch)
        except ValueError:
            # Not a number: force the range check below to fail.
            ch = -1
        if item in items and 0 <= ch < count and op in ops:
            set_func, get_func = items[item]
            if op == 'get':
                print(get_func(ch))
                return True
            if len(args) >= 4:
                try:
                    val = int(args[3])
                except ValueError:
                    val = None
                if val is not None:
                    set_func(ch, val)
                    return True
    print('Arguments for input command:')
    print('  input ITEM CH OP [ENABLE]')
    print('    ITEM:   [{0}]'.format('|'.join(items)))
    print('    CH:     [0-{0}]'.format(count - 1))
    print('    OP:     [{0}]'.format('|'.join(ops)))
    print('    ENABLE: [0|1] if OP=set')
    return False


def handle_playback(unit, args):
    """Get or set a parameter of a playback channel.

    Args:
        unit: Object providing ``info['playback-channels']`` and the
            ``set_playback_*`` / ``get_playback_*`` methods.
        args: ``[ITEM, CH, 'get']`` or ``[ITEM, CH, 'set', VALUE]``. VALUE is
            parsed as float for 'gain' and as int for the other items.

    Returns:
        True when the request was carried out, False after printing usage
        (including when CH or VALUE is malformed or CH is out of range).
    """
    items = {
        # key: (set_func, get_func)
        'gain': (unit.set_playback_gain, unit.get_playback_gain),
        'mute': (unit.set_playback_mute, unit.get_playback_mute),
        'solo': (unit.set_playback_solo, unit.get_playback_solo),
    }
    ops = ('set', 'get')
    count = unit.info['playback-channels']
    if len(args) >= 3:
        item, ch, op = args[0:3]
        try:
            ch = int(ch)
        except ValueError:
            # Not a number: force the range check below to fail.
            ch = -1
        if item in items and 0 <= ch < count and op in ops:
            set_func, get_func = items[item]
            if op == 'get':
                print(get_func(ch))
                return True
            if len(args) >= 4:
                try:
                    val = float(args[3]) if item == 'gain' else int(args[3])
                except ValueError:
                    val = None
                if val is not None:
                    set_func(ch, val)
                    return True
    print('Arguments for playback command:')
    print('  playback ITEM CH OP [dB|ENABLE]')
    print('    ITEM:   [{0}]'.format('|'.join(items)))
    print('    CH:     [0-{0}]'.format(count - 1))
    print('    OP:     [{0}]'.format('|'.join(ops)))
    print('    dB:     [-144.0..0] if OP=set and ITEM=gain')
    print('    ENABLE: [0|1]       if OP=set and ITEM=mute')
    print('    ENABLE: [0|1]       if OP=set and ITEM=solo')
    return False


def handle_monitor(unit, args):
    """Get or set a monitor parameter for an input/output channel pair.

    Args:
        unit: Object providing ``info['capture-channels']``,
            ``info['playback-channels']`` and the ``set_monitor_*`` /
            ``get_monitor_*`` methods.
        args: ``[ITEM, INPUT, OUTPUT, 'get']`` or
            ``[ITEM, INPUT, OUTPUT, 'set', VALUE]``. VALUE is parsed as float
            for 'gain' and as int for the other items.

    Returns:
        True when the request was carried out, False after printing usage
        (including when a channel or VALUE is malformed or out of range).
    """
    items = {
        'gain':  (unit.set_monitor_gain, unit.get_monitor_gain),
        'mute':  (unit.set_monitor_mute, unit.get_monitor_mute),
        'solo':  (unit.set_monitor_solo, unit.get_monitor_solo),
        'pan':   (unit.set_monitor_pan, unit.get_monitor_pan,)
    }
    ops = ('set', 'get')
    captures = unit.info['capture-channels']
    playbacks = unit.info['playback-channels']
    if len(args) >= 4:
        item, src, dst, op = args[0:4]
        try:
            src = int(src)
            dst = int(dst)
        except ValueError:
            # Not numbers: force the range checks below to fail.
            src = dst = -1
        if (item in items and 0 <= src < captures and
                0 <= dst < playbacks and op in ops):
            set_func, get_func = items[item]
            if op == 'get':
                print(get_func(src, dst))
                return True
            if len(args) >= 5:
                try:
                    val = float(args[4]) if item == 'gain' else int(args[4])
                except ValueError:
                    val = None
                if val is not None:
                    set_func(src, dst, val)
                    return True
    print('Arguments for monitor command:')
    print('  monitor ITEM INPUT OUTPUT OP [dB|BALANCE|ENABLE]')
    print('    ITEM:   [{0}]'.format('|'.join(items)))
    print('    INPUT:  [0-{0}]'.format(captures - 1))
    print('    OUTPUT: [0-{0}]'.format(playbacks - 1))
    print('    OP:     [{0}]'.format('|'.join(ops)))
    print('    dB:     [-144.0..0.0]           if OP=set and ITEM=gain')
    print('    ENABLE: [0|1]                   if OP=set and ITEM=mute')
    print('    ENABLE: [0|1]                   if OP=set and ITEM=solo')
    print('    BALANCE:[0..255] (left-right)   if OP=set and ITEM=pan')
    return False


def handle_listen_metering(unit, args):
    """Print the meters of the unit repeatedly until SIGINT arrives.

    Args:
        unit: Object providing ``get_metering()``, whose result is iterated
            for names and indexed by name.
        args: Remaining command-line arguments; not used by this command.

    Returns:
        The loop has no exit of its own; the process leaves through
        ``sys.exit()`` in the SIGINT handler, so the trailing return is
        never reached.
    """
    def quit_on_interrupt(signum, frame):
        # Called by the interpreter on SIGINT, outside the loop below.
        sys.exit()

    signal.signal(signal.SIGINT, quit_on_interrupt)

    while True:
        levels = unit.get_metering()
        for label in sorted(levels):
            print(label, levels[label])
        # Blank line separates one snapshot from the next.
        print('')
        time.sleep(0.1)
    return True


def get_available_commands(features):
    """Build the table of sub-commands available for a unit.

    Args:
        features: Mapping from feature name to a truthy/falsy flag. Every
            feature name consulted below must be present as a key.

    Returns:
        Dict mapping a command name to its handler function.
    """
    cmds = {
        'hardware-info':        handle_hardware_info,
        'clock-source':         handle_clock_source,
        'sampling-rate':        handle_sampling_rate,
        'box-state':            handle_box_state,
        'output':               handle_output,
        'playback':             handle_playback,
        'monitor':              handle_monitor,
        'listen-metering':      handle_listen_metering,
    }

    # (feature flag, command name, handler) for commands tied to one feature.
    optional = (
        ('control-room-mirroring', 'control-room', handle_control_room),
        ('rx-mapping', 'rx-stream-mapping', handle_rx_stream_mapping),
        ('tx-mapping', 'tx-stream-mapping', handle_tx_stream_mapping),
        ('phantom-powering', 'phantom-powering', handle_phantom_powering),
        ('nominal-input', 'input', handle_input),
    )
    for feature, name, handler in optional:
        if features[feature]:
            cmds[name] = handler

    # Choosing a digital input only makes sense with two or more interfaces.
    digital_ifaces = ('spdif-coax', 'spdif-opt', 'aesebu-xlr', 'adat-opt')
    present = [iface for iface in digital_ifaces if features[iface]]
    if len(present) > 1:
        cmds['digital-input'] = handle_digital_input

    return cmds


# Script body: runs when the file is executed (there is no __main__ guard).
# When CliKit returns a falsy path nothing further happens.
fullpath = CliKit.seek_snd_unit_path()
if fullpath:
    # The unit is used as a context manager for the duration of the command.
    with EfwUnit(fullpath) as unit:
        # Offer only the commands matching the features the unit reports.
        cmds = get_available_commands(unit.info['features'])
        # NOTE(review): presumably picks a handler from the command line and
        # calls it as handler(unit, args) - confirm against CliKit.
        CliKit.dispatch_command(unit, cmds)
