# Quentin Bolsée 2023-10-01
# MIT Center for Bits and Atoms
#
# This work may be reproduced, modified, distributed,
# performed, and displayed for any purpose, but must
# acknowledge this project. Copyright is retained and
# must be preserved. The work is provided as is; no
# warranty is provided, and users accept all liability.
import os.path
import serial
import time
import re
import ast


class MicroPythonError(RuntimeError):
    """Error reported by the board while executing a command.

    Raised by ``Device.run`` whenever the raw REPL sends back a non-empty
    error section; the exception message is that error text.
    """


class Device:
    """Host-side proxy for a MicroPython board driven through its raw REPL.

    On construction the serial port is opened, the board is switched to raw
    REPL mode, ``main_filename`` is imported on the board, and every top-level
    ``def`` found in that file is exposed as an attribute of this object that
    runs the function remotely. Arguments and return values travel as
    ``repr`` text and are parsed back with ``ast.literal_eval``, so they must
    be Python literals.

    NOTE: remote functions are installed with ``setattr`` on the instance, so
    a top-level function in the main file named like a Device method (e.g.
    ``run`` or ``close``) shadows that method on this instance.
    """

    # seconds to wait after entering raw REPL before discarding pending input
    init_delay = 0.1

    def __init__(self, port, main_filename="main.py"):
        """Open ``port`` and load the remote functions of ``main_filename``.

        Args:
            port: serial port name passed to ``serial.Serial``.
            main_filename: file on the board whose functions are exposed.

        Raises:
            FileNotFoundError: ``main_filename`` cannot be opened on the board.
            MicroPythonError: the board reports an error, e.g. on import.
        """
        self.ser_obj = serial.Serial(port)
        self.function_args = {}
        self.function_callable = {}
        self.main_filename = main_filename
        # strip only a trailing ".py" (str.replace would also hit inner matches)
        if main_filename.endswith(".py"):
            self.main_module = main_filename[:-len(".py")]
        else:
            self.main_module = main_filename
        try:
            self.init()
            self.read_functions()
        except BaseException:
            # do not leave the serial port open when construction fails
            self.ser_obj.close()
            raise

    def init(self):
        """Interrupt the board, enter raw REPL mode and import the main module."""
        s = self.ser_obj
        s.write(b'\x04')  # ctrl-D
        s.write(b'\r\n\x03\r\n')  # ctrl-C: interrupt whatever is running
        s.write(b'\x01')  # ctrl-A: enter raw REPL
        time.sleep(self.init_delay)
        # clean slate: drop banners/prompts printed so far
        s.reset_input_buffer()
        # make sure main file is imported
        self.run(f'import {self.main_module}')

    def run(self, cmd, show=False, end="\n"):
        """Execute ``cmd`` on the board and return its printed output.

        The raw REPL answers ``OK<output>\\x04<error>\\x04`` followed by a
        ``>`` prompt, which is left unread until the next call.

        Args:
            cmd: Python source to execute on the board.
            show: if true, also print the returned text locally.
            end: appended to ``cmd`` before sending.

        Returns:
            The command's output with trailing whitespace stripped.

        Raises:
            MicroPythonError: the error section of the reply is non-empty.
        """
        s = self.ser_obj
        s.write((cmd + end).encode("utf-8"))
        s.write(b'\x04')  # ctrl-D: run the buffered command

        # Skip up to the "OK" acknowledgement. This also swallows the ">"
        # prompt left by the previous command, which is absent right after
        # init() flushed the input -- a fixed-width slice would cut output.
        s.read_until(b"OK")

        # <RETURN>\x04
        txt_ret = s.read_until(b"\x04")[:-1].decode("utf-8")

        # <ERROR>\x04
        txt_err = s.read_until(b"\x04")[:-1].decode("utf-8")

        if txt_err:
            raise MicroPythonError(txt_err)

        txt_ret = txt_ret.rstrip()
        if show:
            print(f"RETURN: '{txt_ret}'")
        return txt_ret

    def run_func(self, func_name, *args, **kwargs):
        """Call ``<main_module>.<func_name>(*args, **kwargs)`` on the board.

        Arguments are sent as their ``repr``; the result is printed with
        ``repr`` on the board and parsed back with ``ast.literal_eval``.
        """
        arg_parts = [repr(x) for x in args]
        arg_parts += [f"{key}={value!r}" for key, value in kwargs.items()]
        call = f"{self.main_module}.{func_name}({','.join(arg_parts)})"
        return ast.literal_eval(self.run(f"print(repr({call}))"))

    def read_functions(self):
        """Expose every top-level function of the main file as an attribute.

        The file is read from the board and scanned for unindented
        single-line ``def name(args):`` headers (an optional ``-> ...``
        return annotation is allowed). Each match is recorded in
        ``function_args`` / ``function_callable`` and set on the instance.

        Raises:
            FileNotFoundError: the main file cannot be opened on the board.
        """
        try:
            self.run(f'f=open({self.main_filename!r},"rb")')
        except MicroPythonError as err:
            raise FileNotFoundError(
                f"Could not find {self.main_filename} on device!") from err

        # read main txt file, always releasing the handle on the board
        try:
            main_bytes = ast.literal_eval(self.run('print(f.read())'))
        finally:
            self.run('f.close()')
        main_txt = main_bytes.decode("utf-8")

        # only top-level defs: indented ones (methods, nested functions) are
        # not attributes of the module and must not shadow Device methods
        pattern = r"^def\s+(\w+)\s*\((.*)\)\s*(?:->[^:\n]*)?:"
        for m in re.finditer(pattern, main_txt, re.MULTILINE):
            name = m.group(1)
            # func_name default binds the current name (avoids late binding)
            func = lambda *args, func_name=name, **kwargs: self.run_func(func_name, *args, **kwargs)
            self.function_args[name] = m.group(2)
            self.function_callable[name] = func
            setattr(self, name, func)

    def __str__(self):
        """List the port and the available remote function signatures."""
        lines = [f"MicroPython Device at {self.ser_obj.port}, available functions:"]
        if self.function_args:
            lines += [f"  {name}({args})" for name, args in self.function_args.items()]
        else:
            lines.append("  none")
        return "\n".join(lines)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    def update(self):
        """Remove the exposed functions, re-initialise and load them again."""
        for name in self.function_args:
            delattr(self, name)
        self.function_args = {}
        self.function_callable = {}
        self.init()
        self.read_functions()

    def upload_main(self, filename):
        """Upload local ``filename`` as the board's main file."""
        self.upload(filename, self.main_filename)

    def upload(self, filename, destination=None, update=True):
        """Copy local text file ``filename`` to the board.

        Args:
            filename: local path, read as UTF-8 text.
            destination: path on the board; defaults to the basename of
                ``filename``.
            update: if true, reload the exposed functions afterwards.
        """
        if destination is None:
            destination = os.path.basename(filename)
        with open(filename, "r", encoding="utf-8") as f:
            file_txt = f.read()
        self.run(f'f = open({destination!r}, "wb")')
        try:
            self.run(f'f.write({file_txt!r})')
        finally:
            # close the remote file even if the write failed
            self.run('f.close()')
        if update:
            self.update()

    def remove(self, filename):
        """Delete ``filename`` from the board's filesystem."""
        self.run('import os')
        self.run(f'os.remove({filename!r})')

    def close(self):
        """Close the serial connection."""
        self.ser_obj.close()