import compiler
import ast
from inspect import getmembers

import pynes.bitbag

class BitArray(object):
    """Byte values collected from a list literal in a PyNES program.

    ``lst`` is a sequence of numeric AST literal nodes; only their ``.n``
    attribute is read.
    """

    def __init__(self, lst):
        self.value = [node.n for node in lst]

    def list(self):
        """Return the collected byte values (the internal list, not a copy)."""
        return self.value

    def to_asm(self, bytes_per_line=16):
        """Render the values as ``.db`` directives.

        Args:
            bytes_per_line: maximum number of values per ``.db`` line
                (must be positive). Defaults to 16.

        Returns:
            The assembly text, one ``.db`` line per chunk of values. The
            last line may be shorter when the count is not a multiple of
            ``bytes_per_line``. Returns False (not '') when there are no
            values, because callers test the result for truthiness.
        """
        hexes = ["$%02X" % v for v in self.value]
        # Step through the data in whole chunks so a trailing partial row is
        # emitted instead of being silently dropped.
        rows = [hexes[i:i + bytes_per_line]
                for i in range(0, len(hexes), bytes_per_line)]
        asm = ''.join('  .db ' + ','.join(row) + '\n' for row in rows)
        if asm:
            return asm
        return False

class BitSprite(BitArray):
    """Sprite placeholder built on BitArray.

    NOTE(review): the constructor accepts position, attribute, tile and
    size arguments but currently ignores all of them. It always starts with
    an empty value list and does not call BitArray.__init__.
    """

    def __init__(self, x, y, attrib=0, tile=0, width=1, height=1):
        # Arguments are accepted for interface compatibility only.
        self.value = list()




class Cartridge(object):
    """Collects the pieces of a PyNES program and renders them as assembly.

    The compiler fills in variables, bitpaks, program code and flags.
    ``to_asm()`` then concatenates the iNES header, ``.rs`` reservations,
    bank 0 code, the NMI routine, bank 1 data and the vector table.
    """

    # Emission order of the iNES header directives.
    _HEADER_KEYS = ('.inesprg', '.ineschr', '.inesmap', '.inesmir')

    def __init__(self):
        self._state = None
        self._asm_chunks = {}  # chunk name -> assembly text

        self.has_reset = False  # set when a `reset` def is compiled
        self.has_nmi = False  # set when an NMI handler is needed

        # Not read by this class; kept for callers that may inspect them.
        self.has_prog = False  # has any program
        self.has_bank1 = False  # has any attrib def
        self.has_chr = False  # has any sprite

        self._header = {'.inesprg': 1, '.ineschr': 1,
                        '.inesmap': 0, '.inesmir': 1}
        self.sprites = []
        self.nametable = {}
        self._vars = {}
        self.bitpaks = {}
        self._progcode = ""
        self._joypad1 = False

    @property
    def state(self):
        """Name of the chunk currently being compiled (key into _asm_chunks)."""
        return self._state

    @state.setter
    def state(self, value):
        # Only record the state. Assigning to ``self.prog`` here (as earlier
        # code did) would shadow the prog() method with a string and make
        # to_asm() fail. The class is new-style so this setter actually runs
        # on Python 2.
        self._state = value

    def headers(self):
        """Return the iNES header directives followed by a blank line."""
        lines = ['%s %s\n' % (key, self._header[key])
                 for key in self._HEADER_KEYS]
        return ''.join(lines) + '\n'

    def boot(self):
        """Return the vector table at $FFFA: NMI, RESET, then a zero entry.

        A vector is ``0`` when the corresponding handler was not compiled.
        """
        nmi_vector = 'NMI' if self.has_nmi else '0'
        reset_vector = 'RESET' if self.has_reset else '0'
        return ("  .org $FFFA\n"
                "  .dw %s\n"
                "  .dw %s\n"
                "  .dw 0\n\n") % (nmi_vector, reset_vector)

    def init(self):
        """Return the RESET prologue.

        It disables IRQs and decimal mode, sets up the stack, and disables
        NMI, rendering and DMC IRQs.
        """
        return (
          '  SEI          ; disable IRQs\n' +
          '  CLD          ; disable decimal mode\n' +
          '  LDX #$40\n' +
          '  STX $4017    ; disable APU frame IRQ\n' +
          '  LDX #$FF\n' +
          '  TXS          ; Set up stack\n' +
          '  INX          ; now X = 0\n' +
          '  STX $2000    ; disable NMI\n' +
          '  STX $2001    ; disable rendering\n' +
          '  STX $4010    ; disable DMC IRQs\n'
        )

    def rsset(self):
        """Return ``.rs`` reservations for every int-valued variable, or ''."""
        lines = [name + ' .rs ' + str(size) + '\n'
                 for name, size in self._vars.items()
                 if isinstance(size, int)]
        if not lines:
            return ""
        return "  .rsset $0000\n" + ''.join(lines) + '\n\n'

    def prog(self):
        """Return bank 0 (bitpak procedures, then program code), or ''."""
        procedures = ''.join(self.bitpaks[name].procedure() + '\n'
                             for name in self.bitpaks)
        asm_code = procedures + self._progcode
        if not asm_code:
            return ""
        return "  .bank 0\n  .org $C000\n\n" + asm_code + '\n\n'

    def bank1(self):
        """Return bank 1 with a label plus ``.db`` data for each BitArray var."""
        asm_code = ""
        for name in self._vars:
            value = self._vars[name]
            if isinstance(value, BitArray):
                data = value.to_asm()  # computed once; False when empty
                if data:
                    asm_code += name + ':\n' + data
        if not asm_code:
            return ""
        return "  .bank 1\n  .org $E000\n\n" + asm_code + '\n\n'

    def nmi(self):
        """Return the NMI routine, or '' when no joypad handler was compiled.

        NOTE(review): the emitted routine ends at ``EndUp:`` with no RTI.
        Confirm whether a return is appended elsewhere.
        """
        if not self._joypad1:
            return ""
        joypad1_code = (
            "\nJoyPad1Up:\n"
            "  LDA $4016\n"
            "  AND #%00000001\n"
            "  BEQ EndUp\n"
        )
        joypad1_code += self._asm_chunks.get('joypad1_up', '')
        joypad1_code += "EndUp:\n"
        nmi_code = (
            "NMI:\n"
            "  LDA #$00\n"
            "  STA $2003 ; Write Only: Sets the offset in sprite ram.\n"
            "  LDA #$02\n"
            "  STA $4014 ; Write Only; DMA\n"
        )
        return nmi_code + joypad1_code + "\n"

    def set_var(self, varname, value):
        """Store a compiled variable (an int size or a BitArray)."""
        self._vars[varname] = value

    def get_var(self, varname):
        """Return a stored variable; raises KeyError when it is unknown."""
        return self._vars[varname]

    def to_asm(self):
        """Render the whole program as assembly text and echo it to stdout."""
        asm_code = ''.join([
            ';Generated by PyNES\n\n',
            self.headers(),
            self.rsset(),
            self.prog(),
            self.nmi(),
            self.bank1(),
            self.boot(),
        ])
        print(asm_code)
        return asm_code


class PyNesVisitor(ast.NodeVisitor):
    """Walk a parsed PyNES program and record what it finds on ``cart``.

    Handlers write into the module-level ``cart`` via ``global cart``. The
    caller must point it at a Cartridge before visiting.
    """

    def generic_visit(self, node, index = 0):
        """Visit child nodes, iterating the node's fields in reverse order.

        ``index`` is accepted for compatibility but unused.
        """
        for _field, value in reversed(list(ast.iter_fields(node))):
            if isinstance(value, list):
                for item in value:
                    if isinstance(item, ast.AST):
                        self.visit(item)
            elif isinstance(value, ast.AST):
                self.visit(value)

    @staticmethod
    def _literal_value(node):
        """Return the Python value of a literal AST node, or None.

        Handles ``ast.Constant`` (Python 3) as well as the older
        ``ast.Str``/``ast.Num`` nodes that expose ``.s``/``.n``.
        """
        constant = getattr(ast, 'Constant', None)
        if constant is not None and isinstance(node, constant):
            return node.value
        return getattr(node, 's', getattr(node, 'n', None))

    def visit_Import(self, node):
        """Ignore imports; they produce no assembly."""
        pass

    def visit_If(self, node):
        """Skip ``if ... == '__main__':`` guards; log any other ``if``.

        The body of an ``if`` is never visited.
        """
        comparators = getattr(node.test, 'comparators', None)
        first = comparators[0] if comparators else None
        if first is not None and self._literal_value(first) == '__main__':
            return
        # Debug output for conditionals the compiler does not handle yet.
        print('IF')
        if first is not None:
            print(dir(first))
            print(self._literal_value(first))

    def visit_Assign(self, node):
        """Record ``name = rs(N)`` reservations and ``name = [...]`` byte lists.

        Raises:
            Exception: for chained assignments with several targets.
        """
        global cart
        if len(node.targets) != 1:
            raise Exception('dammit')
        target = node.targets[0]
        if not isinstance(target, ast.Name):
            # Attribute, subscript and tuple targets are not supported yet.
            return
        varname = target.id
        value = node.value
        if isinstance(value, ast.Call):
            func = value.func
            if isinstance(func, ast.Name):
                if func.id == 'rs':
                    cart.set_var(varname, value.args[0].n)
            elif (isinstance(func, ast.Attribute)
                  and isinstance(func.value, ast.Name)
                  and func.value.id == 'pynes'
                  and func.attr == 'rsset'):
                # pynes.rsset(...) is recognised but not compiled yet.
                pass
        elif isinstance(value, ast.List):
            cart.set_var(varname, BitArray(value.elts))

    def visit_FunctionDef(self, node):
        """Emit ``reset``/``nmi`` code and register ``joypad1_*`` handlers.

        Other function definitions are ignored and their bodies are skipped.
        """
        global cart
        if node.name in ('reset', 'nmi'):
            cart._progcode += node.name.upper() + ':\n'
            if node.name == 'reset':
                cart.has_reset = True
                cart._progcode += cart.init()
            else:
                cart.has_nmi = True
            self.generic_visit(node)
        elif node.name.startswith('joypad1_'):
            cart.has_nmi = True
            cart._joypad1 = True
            cart.state = node.name
            # NOTE(review): every joypad1_* handler currently gets this same
            # fixed "decrement py" sequence. The button suffix of the name is
            # not used yet.
            cart._asm_chunks[cart.state] = (
                "  LDA py          ; Y position\n"
                "  SEC\n"
                "  SBC #$01        ; Y = Y - 1\n"
                "  STA py\n")
            self.generic_visit(node)

    def visit_Call(self, node):
        """Compile bitpak calls (``name()``) and recognise ``pynes.*`` calls."""
        global cart
        func = node.func
        if isinstance(func, ast.Name):
            name = func.id
            if name not in cart.bitpaks:
                obj = getattr(pynes.bitbag, name, None)
                if obj:
                    bp = obj()
                    cart.bitpaks[name] = bp
                    cart._progcode += bp()
            else:
                bp = cart.bitpaks[name]
                cart._progcode += bp()
        elif (isinstance(func, ast.Attribute)
              and isinstance(func.value, ast.Name)
              and func.value.id == 'pynes'):
            if func.attr == 'wait_vblank':
                print('wait_vblank')
            elif func.attr == 'load_sprite':
                print('load_sprite')

    def visit_Add(self, node):
        """Debug hook for ``+`` operators."""
        print('this is an ADD')

    def visit_Sub(self, node):
        """Debug hook for ``-`` operators."""
        print(node)

    def visit_BinOp(self, node):
        """Visit operands/operator, then print debug info about the left side."""
        self.generic_visit(node)
        print('BinOp')
        print(node.left._fields)
        print(node.left)
        # None when the left operand is not a literal (it used to crash).
        print(self._literal_value(node.left))

    def visit_Name(self, node):
        """Debug hook for name references."""
        print(node.id + 'oi')

# Cartridge currently being populated. PyNesVisitor methods read and modify
# it through `global cart`; pynes_compiler() assigns it before visiting and
# resets it to None afterwards.
cart = None

def pynes_compiler(code, cartridge=None):
    """Compile PyNES Python source into a Cartridge.

    Args:
        code: Python source text of the PyNES program.
        cartridge: Cartridge to populate. A new one is created when None.

    Returns:
        The populated Cartridge.
    """
    global cart
    if cartridge is None:
        cartridge = Cartridge()
    # The visitor methods write to the module-level ``cart``. Point it at the
    # cartridge being compiled, including one supplied by the caller.
    cart = cartridge
    try:
        python_land = ast.parse(code)
        turist = PyNesVisitor()
        turist.visit(python_land)
    finally:
        # Never leave a stale cartridge behind, even if compilation fails.
        cart = None
    return cartridge