require 'logger'

# Streams 16-bit CHIP-8 opcodes out of an IO-like object (File, StringIO, ...).
#
# Iteration reads from the IO's *current* position, so #seek and #skip_next
# change which opcode is yielded next.
class OpcodeGenerator
  include Enumerable

  OPCODE_SIZE = 2

  # @param source_file [IO, StringIO] readable, seekable byte source
  # @param unpack_format [String] String#unpack directive used to decode each
  #   2-byte opcode. The default 'S' is native-endian (what existing callers
  #   pack their data with); CHIP-8 ROM images store opcodes most-significant
  #   byte first, so pass 'n' to read those.
  def initialize(source_file, unpack_format = 'S')
    @source_file = source_file
    @unpack_format = unpack_format
  end

  # Yields each opcode as an Integer until the source is exhausted.
  #
  # A trailing fragment shorter than OPCODE_SIZE is ignored instead of being
  # decoded to nil (which would crash whatever consumes the opcode).
  #
  # @yieldparam opcode [Integer]
  # @return [Enumerator] when called without a block, otherwise nil
  def each
    return enum_for(:each) unless block_given?

    while (text = @source_file.read(OPCODE_SIZE)) && text.bytesize == OPCODE_SIZE
      yield text.unpack1(@unpack_format)
    end
  end

  # Moves to absolute byte offset +address+ within the source (offset from the
  # start of the IO, not a CHIP-8 memory address).
  def seek(address)
    @source_file.seek address
  end

  # Advances past the next opcode without yielding it.
  def skip_next
    @source_file.seek(OPCODE_SIZE, IO::SEEK_CUR)
  end

end

# Interpreter for a subset of the CHIP-8 instruction set.
#
# The sixteen 8-bit registers V0..VF live in @register; VF doubles as the
# carry / borrow / shifted-out-bit flag. Opcodes come from an "opcode
# generator" that responds to #each (yielding 16-bit Integers),
# #seek(address) and #skip_next.
#
# Implemented opcodes:
#   0NNN         stop execution
#   1NNN         jump (opcode_generator.seek NNN)
#   3XKK         skip next instruction if VX == KK
#   6XKK         VX = KK
#   7XKK         VX = VX + KK (wraps at 8 bits, VF untouched)
#   8XY0..8XY7,  register-to-register logic / arithmetic / shifts
#   8XYE
#   CXKK         VX = random byte AND KK
#
# Register values are readable through dynamic accessors such as
# emulator.V1 or emulator.VF.
class Chip_8_Emulator

  REGISTER_COUNT = 16
  # Anchored so only the exact names V0..VF are treated as register readers.
  REGISTER_PATTERN = /\AV([0-9A-F])\z/
  REGISTER_MASK = 0xFF
  CARRY_MASK = 0x100
  LEFT_SHIFT_MASK = 0x80
  RIGHT_SHIFT_MASK = 0x1
  CARRY_REGISTER = 0xF

  # @param logger [Logger, nil] destination for per-instruction debug traces.
  #   When nil, the global $LOG is used if it is set; otherwise traces are
  #   discarded.
  def initialize(logger = nil)
    @logger = logger
    @register = Array.new(REGISTER_COUNT, 0)
  end

  # Resets every register to zero, then executes opcodes from
  # +opcode_generator+ until a 0NNN opcode is seen or the generator runs dry.
  #
  # @param opcode_generator [#each, #seek, #skip_next] opcode source
  # @raise [RuntimeError] for an opcode this emulator does not implement
  def run(opcode_generator)
    @register.fill 0
    opcode_generator.each do |opcode|
      case opcode & 0xF000
      when 0x0000 # exit gracefully
        log.debug { sprintf("0x%04X Exit", opcode) }
        break
      when 0x1000 # Jump to NNN
        address = address_from_opcode opcode
        log.debug { sprintf("0x%04X Jump to 0x%04X", opcode, address) }
        opcode_generator.seek address
      when 0x3000 # Skip next instruction if VX == KK
        x, k = register_and_constant_from_opcode opcode
        log.debug { sprintf("0x%04X Skip if V%X == 0x%04X", opcode, x, k) }
        opcode_generator.skip_next if @register[x] == k
      when 0x6000 # VX = KK
        x, k = register_and_constant_from_opcode opcode
        log.debug { sprintf("0x%04X V%X = 0x%04X", opcode, x, k) }
        @register[x] = k
      when 0x7000 # VX = VX + KK; unlike 8XY4 this never touches VF
        x, k = register_and_constant_from_opcode opcode
        log.debug { sprintf("0x%04X V%X = V%X + 0x%04X", opcode, x, x, k) }
        @register[x] = (@register[x] + k) & REGISTER_MASK
      when 0x8000 # register-to-register operations
        execute_register_operation opcode
      when 0xC000 # VX = Random number AND KK
        x, k = register_and_constant_from_opcode opcode
        r = rand(0x100) # full byte range 0x00..0xFF
        log.debug { sprintf("0x%04X V%X = random number 0x%04X AND 0x%04X", opcode, x, r, k) }
        @register[x] = r & k
      else
        raise RuntimeError, "Unknown opcode 0x#{opcode.to_s(16)}"
      end
    end
  end

  # @return [Integer] the NNN address field of +opcode+
  def address_from_opcode(opcode)
    opcode & 0x0FFF
  end

  # @return [Integer] the X register index (bits 8-11) of +opcode+
  def register_from_opcode(opcode)
    (opcode & 0x0F00) >> 8
  end

  # @return [Array(Integer, Integer)] the X register index and KK byte
  def register_and_constant_from_opcode(opcode)
    x = (opcode & 0x0F00) >> 8
    k = opcode & 0x00FF

    [x, k]
  end

  # @return [Array(Integer, Integer)] the X and Y register indices
  def register_pair_from_opcode(opcode)
    x = (opcode & 0x0F00) >> 8
    y = (opcode & 0x00F0) >> 4

    [x, y]
  end

  # Returns (a + b) & 0xFF; sets VF to 1 if the 8-bit sum overflowed, else 0.
  def short_int_add(a, b)
    sum = a + b
    @register[CARRY_REGISTER] = (sum & CARRY_MASK) >> 8
    sum & REGISTER_MASK
  end

  # Returns (a - b) & 0xFF; sets VF to 1 when a >= b (no borrow), else 0.
  def short_int_subtract(a, b)
    # Pre-setting bit 8 lets us read "no borrow" straight out of that bit.
    difference = (a | CARRY_MASK) - b
    @register[CARRY_REGISTER] = (difference & CARRY_MASK) >> 8
    difference & REGISTER_MASK
  end

  # Returns (a << 1) & 0xFF; sets VF to the bit shifted out (0 or 1).
  def short_int_shift_left(a)
    @register[CARRY_REGISTER] = (a & LEFT_SHIFT_MASK) >> 7
    (a << 1) & REGISTER_MASK
  end

  # Returns a >> 1; sets VF to the bit shifted out (0 or 1).
  def short_int_shift_right(a)
    @register[CARRY_REGISTER] = a & RIGHT_SHIFT_MASK
    a >> 1
  end

  # Read-only register accessors V0..VF (e.g. emulator.VA).
  # Any other missing method falls through to the default NoMethodError.
  def method_missing(method_id, *args, &block)
    match_object = REGISTER_PATTERN.match(method_id.to_s)
    return super if match_object.nil? || !args.empty?

    @register[match_object[1].hex]
  end

  # Keeps respond_to? consistent with the dynamic register accessors.
  def respond_to_missing?(method_id, include_private = false)
    REGISTER_PATTERN.match?(method_id.to_s) || super
  end

  private

  # Executes a single 8XYN register-to-register opcode.
  # @raise [RuntimeError] for an unimplemented N nibble
  def execute_register_operation(opcode)
    case opcode & 0x000F
    when 0x0 # VX = VY
      x, y = register_pair_from_opcode opcode
      log.debug { sprintf("0x%04X V%X = V%X", opcode, x, y) }
      @register[x] = @register[y]
    when 0x1 # VX = VX OR VY
      x, y = register_pair_from_opcode opcode
      log.debug { sprintf("0x%04X V%X = V%X OR V%X", opcode, x, x, y) }
      @register[x] |= @register[y]
    when 0x2 # VX = VX AND VY
      x, y = register_pair_from_opcode opcode
      log.debug { sprintf("0x%04X V%X = V%X AND V%X", opcode, x, x, y) }
      @register[x] &= @register[y]
    when 0x3 # VX = VX XOR VY
      x, y = register_pair_from_opcode opcode
      log.debug { sprintf("0x%04X V%X = V%X XOR V%X", opcode, x, x, y) }
      @register[x] ^= @register[y]
    when 0x4 # VX = VX + VY, VF = carry
      x, y = register_pair_from_opcode opcode
      log.debug { sprintf("0x%04X V%X = V%X + V%X", opcode, x, x, y) }
      @register[x] = short_int_add(@register[x], @register[y])
    when 0x5 # VX = VX - VY, VF = NOT borrow
      x, y = register_pair_from_opcode opcode
      log.debug { sprintf("0x%04X V%X = V%X - V%X", opcode, x, x, y) }
      @register[x] = short_int_subtract(@register[x], @register[y])
    when 0x6 # VX = VX shift right 1, VF = bit shifted out
      x = register_from_opcode opcode
      log.debug { sprintf("0x%04X V%X = V%X shift right 1", opcode, x, x) }
      @register[x] = short_int_shift_right(@register[x])
    when 0x7 # VX = VY - VX, VF = NOT borrow
      x, y = register_pair_from_opcode opcode
      log.debug { sprintf("0x%04X V%X = V%X - V%X", opcode, x, y, x) }
      @register[x] = short_int_subtract(@register[y], @register[x])
    when 0xE # VX = VX shift left 1, VF = bit shifted out
      x = register_from_opcode opcode
      log.debug { sprintf("0x%04X V%X = V%X shift left 1", opcode, x, x) }
      @register[x] = short_int_shift_left(@register[x])
    else
      raise RuntimeError, "Unknown register opcode 0x#{opcode.to_s(16)}"
    end
  end

  # Logger for instruction traces: the injected one, else $LOG, else a
  # logger with no output device so the emulator works when $LOG is unset.
  def log
    @logger || $LOG || (@null_logger ||= Logger.new(nil))
  end

end

if __FILE__ == $0

  $LOG = Logger.new STDOUT
  $LOG.info "program starts"

  require 'test/unit'
  require 'stringio'

  class TestOpcodeGenerator < Test::Unit::TestCase

    TEST_DATA = [0x0001, 0x0002, 0x0003]

    def test_generator
      opcodes = OpcodeGenerator.new StringIO.new(TEST_DATA.pack("S*"))
      # to_a also catches a generator yielding too many or too few opcodes
      assert_equal TEST_DATA, opcodes.to_a
    end

    def test_seek
      opcodes = OpcodeGenerator.new StringIO.new(TEST_DATA.pack("S*"))
      opcodes.seek 2
      assert_equal TEST_DATA.drop(1), opcodes.to_a
    end

    def test_skip_next
      opcodes = OpcodeGenerator.new StringIO.new(TEST_DATA.pack("S*"))
      opcodes.seek 2
      opcodes.skip_next
      assert_equal TEST_DATA.drop(2), opcodes.to_a
    end

  end

  class TestChip_8 < Test::Unit::TestCase

    # dump of file Chip8Test.html
    TEST_DATA = [0x6177, 0x6245, 0x7101, 0x8320, 0x8121, 0x8122, 0x8233, 0x8134,
                 0x8235, 0x8106, 0x8327, 0x830E, 0x64FF, 0xC411, 0x32BB, 0x1000,
                 0x0000].pack('S*')

    # C411 masks the random byte with 0x11, so V4 can only be one of these.
    POSSIBLE_V4 = [0x00, 0x01, 0x10, 0x11]

    # Expected register contents after the program runs:
    # V1:01000101
    # V2:10111011
    # V3:11101100
    # V4:random, but always one of POSSIBLE_V4
    # VF:00000000
    def test_emulator
      emulator = Chip_8_Emulator.new

      # run resets the registers, so every pass must reproduce the same state
      2.times do
        emulator.run OpcodeGenerator.new(StringIO.new(TEST_DATA))
        assert_equal 0b01000101, emulator.V1
        assert_equal 0b10111011, emulator.V2
        assert_equal 0b11101100, emulator.V3
        assert_includes POSSIBLE_V4, emulator.V4
        assert_equal 0b00000000, emulator.VF
      end
    end

  end

end