"""
MicroPython implementation of a cooperative task-scheduling system,
based on the protothreads concept for cooperative multitasking.
"""
import time
from micropython import const

# Constants

# Task modes (TaskData.mode). TaskScheduler.task_proc deactivates a
# TASK_MODE_ONCE task after its procedure returns PT_ENDED; TASK_MODE_PERIOD
# tasks are not deactivated by task_proc.
TASK_MODE_ONCE = const(0)
TASK_MODE_PERIOD = const(1)

# Timer modes accepted by TaskScheduler.start_timer. On expiry a
# TIMER_MODE_ONCEROUTINE timer frees its slot; the other two restart
# their period from the current pass.
TIMER_MODE_TIMER = const(0)
TIMER_MODE_ONCEROUTINE = const(1)
TIMER_MODE_CYCROUTINE = const(2)

MAX_TIMER = const(5)  # number of timer slots TaskScheduler allocates
TICK_PER_S = const(1000)  # 1000 ticks per second (not referenced in the code shown)

# Task states: protothread return codes a task procedure may return.
# TaskScheduler.task_proc only checks for PT_ENDED.
PT_WAITING = const(0)
PT_YIELDED = const(1)
PT_EXITED = const(2)
PT_ENDED = const(3)

class PT:
    """Per-task protothread state.

    ``lc`` is the resume point that the task procedure reads and advances
    on each call; a freshly created thread starts at 0.
    """

    def __init__(self):
        start_point = 0
        self.lc = start_point

class Timer:
    """Millisecond countdown built on the MicroPython tick counter.

    Arm it with :meth:`set`, then poll :meth:`expired`.
    """

    def __init__(self):
        # Unarmed: both fields are zero until set() is called.
        self.start, self.interval = 0, 0

    def set(self, interval):
        """Arm the timer for *interval* milliseconds, starting now."""
        self.interval = interval
        self.start = time.ticks_ms()

    def expired(self):
        """Return True once at least ``interval`` ms have passed since set()."""
        elapsed = time.ticks_diff(time.ticks_ms(), self.start)
        return elapsed >= self.interval

class TaskData:
    """Bookkeeping record for one scheduled task."""

    def __init__(self, task_id, mode, name, period, task_proc):
        # Caller-supplied configuration.
        self.task_id, self.name, self.mode = task_id, name, mode
        self.period, self.task_proc = period, task_proc
        # Runtime state maintained by the scheduler; tasks start inactive.
        self.active = False
        self.last_time = 0
        self.pt = PT()

class TimerData:
    """One slot in the scheduler's timer pool (initially free)."""

    def __init__(self):
        self.in_use, self.timer_id = False, None
        self.mode = self.period = self.last_time = 0
        self.routine = self.param = None

class TaskScheduler:
    """Cooperative scheduler for periodic tasks and software timers.

    The application calls :meth:`task_proc` repeatedly from its main loop.
    Each pass first services the in-use timers, then runs every active task
    whose period has elapsed. Times are ``time.ticks_ms()`` values compared
    with ``time.ticks_diff`` so the arithmetic survives tick wrap-around.

    Status-returning methods return a non-negative value on success and -1
    when the requested task or timer does not exist.
    """

    def __init__(self):
        self._tasks = []  # TaskData records, in the order add_task() received them
        self._timers = [TimerData() for _ in range(MAX_TIMER)]  # fixed-size timer pool
        # These counters are initialised but not used by any method of this
        # class; kept so external code that reads them keeps working.
        self._tmp_1ms_count = 0
        self._all_tick_count = 0
        self._int_cnt = 0
        self._v_int_cnt = 0

    # ------------------------------------------------------------------ tasks

    def _find_task(self, task_id):
        """Return ``(index, task)`` for the first task with *task_id*, else ``(-1, None)``."""
        for index, task in enumerate(self._tasks):
            if task.task_id == task_id:
                return index, task
        return -1, None

    @staticmethod
    def _activate(task, period):
        """Mark *task* active with *period* ms and make it due on the next pass.

        ``last_time`` is back-dated by one period with ``ticks_add`` instead of
        being set to 0. ``ticks_diff(now, 0)`` turns negative once the tick
        counter is past half of its wrap period, which would keep the task
        from ever becoming due.
        NOTE(review): assumes *period* is an int number of ms well below half
        the tick wrap period -- confirm for callers using other values.
        """
        task.active = True
        task.period = period
        task.last_time = time.ticks_add(time.ticks_ms(), -period)

    def active_task(self, task_id, period):
        """Activate task *task_id* with a period of *period* ms.

        An inactive task becomes due on the next task_proc() pass. A task
        that is already active is left untouched (its period is not changed).
        Returns 0 if the task exists, -1 if no task has *task_id*.
        """
        _, task = self._find_task(task_id)
        if task is None:
            return -1
        if not task.active:
            self._activate(task, period)
        return 0

    def active_task_now(self, task_id, period):
        """Activate task *task_id* so that it runs on the next task_proc() pass.

        Same activation as :meth:`active_task`. Returns 0 if the task exists,
        -1 otherwise.
        """
        return self.active_task(task_id, period)

    def suspend_task(self, task_id):
        """Deactivate task *task_id*; return its index, or -1 if not found."""
        index, task = self._find_task(task_id)
        if task is not None:
            task.active = False
        return index

    def set_task_last_time(self, task_id, last_time):
        """Overwrite the last-run tick of task *task_id*; return its index or -1."""
        index, task = self._find_task(task_id)
        if task is not None:
            task.last_time = last_time
        return index

    def add_task(self, task_id, mode, name, period, task_proc):
        """Register a new (initially inactive) task; activate it with active_task()."""
        task = TaskData(task_id, mode, name, period, task_proc)
        self._tasks.append(task)

    # ----------------------------------------------------------------- timers

    def start_timer(self, mode, period, routine, param=None):
        """Claim a free timer slot and start it.

        *routine*, if truthy, is called as ``routine(param)`` each time the
        timer expires. Returns the slot index (also stored in
        ``timer_id``), or -1 for an unknown *mode* or when every slot is busy.
        """
        # Validate once up front instead of on every loop iteration.
        if mode not in (TIMER_MODE_TIMER, TIMER_MODE_ONCEROUTINE, TIMER_MODE_CYCROUTINE):
            return -1
        for timer_id, timer in enumerate(self._timers):
            if not timer.in_use:
                timer.in_use = True
                timer.timer_id = timer_id  # previously declared but never filled in
                timer.mode = mode
                timer.period = period
                timer.last_time = time.ticks_ms()
                timer.routine = routine
                timer.param = param
                return timer_id
        return -1

    @staticmethod
    def _release_timer(timer):
        """Free *timer*'s slot and drop its callback/param references."""
        timer.in_use = False
        # Release references so a stopped timer does not keep them alive.
        timer.routine = None
        timer.param = None

    def stop_timer(self, timer_id):
        """Stop timer *timer_id*; return 0, or -1 if the id is out of range."""
        # Bound-check against the actual pool rather than the module constant.
        if not 0 <= timer_id < len(self._timers):
            return -1
        self._release_timer(self._timers[timer_id])
        return 0

    def stop_all_timers(self):
        """Stop every timer in the pool."""
        for timer in self._timers:
            self._release_timer(timer)

    # -------------------------------------------------------------- main pass

    def task_proc(self):
        """Run one scheduler pass: due timers first, then due active tasks.

        Both phases use the same ``ticks_ms()`` snapshot taken at entry.
        Callbacks run synchronously; an exception from a timer routine or a
        task procedure propagates out of this call.
        """
        now = time.ticks_ms()
        self._run_timers(now)
        self._run_tasks(now)

    def _run_timers(self, now):
        """Fire every in-use timer whose period has elapsed at *now*."""
        for timer in self._timers:
            if not timer.in_use:
                continue
            if time.ticks_diff(now, timer.last_time) < timer.period:
                continue
            if timer.mode == TIMER_MODE_ONCEROUTINE:
                timer.in_use = False  # one-shot: free the slot before the callback
            elif timer.mode in (TIMER_MODE_TIMER, TIMER_MODE_CYCROUTINE):
                timer.last_time = now  # periodic: restart the period from this pass
            if timer.routine:
                timer.routine(timer.param)

    def _run_tasks(self, now):
        """Run every active task whose period has elapsed at *now*."""
        for task in self._tasks:
            if not task.active:
                continue
            if time.ticks_diff(now, task.last_time) < task.period:
                continue
            result = task.task_proc(task.pt)
            if result == PT_ENDED and task.mode == TASK_MODE_ONCE:
                task.active = False  # one-shot task finished its protothread
            task.last_time = now

    def get_system_ms(self):
        """Return the current ``time.ticks_ms()`` value."""
        return time.ticks_ms()

# Example usage:
def example():
    """Demo: register one periodic protothread-style task and tick forever.

    Runs the scheduler's main loop and never returns.
    """
    scheduler = TaskScheduler()

    def my_task(pt):
        """Two-step task that resumes from ``pt.lc`` on every call."""
        PT_BEGIN = 0
        if pt.lc == PT_BEGIN:
            # First step of the task logic goes here.
            pt.lc = 1
            return PT_WAITING
        if pt.lc == 1:
            # Second step; rewind so the next period starts from the top.
            pt.lc = PT_BEGIN
            return PT_ENDED
        # Unknown continuation value: restart instead of looping forever.
        pt.lc = PT_BEGIN
        return PT_WAITING

    # Register the task with a 1000 ms period. add_task() creates it
    # inactive, so it must be activated explicitly or it never runs.
    scheduler.add_task(1, TASK_MODE_PERIOD, "MyTask", 1000, my_task)
    scheduler.active_task(1, 1000)

    # Main loop
    while True:
        scheduler.task_proc()
        time.sleep_ms(1)

# Run the demo main loop only when executed as a script, not on import.
if __name__ == '__main__':
    example()