// -*- mode:c++ -*-

// Copyright (c) 2010 ARM Limited
// All rights reserved
//
// The license below extends only to copyright in the software and shall
// not be construed as granting a license to any other intellectual
// property including but not limited to intellectual property relating
// to a hardware implementation of the functionality of the software
// licensed hereunder.  You may use the software subject to the license
// terms below provided that you ensure that this notice is replicated
// unmodified and in its entirety in all distributions of the software,
// modified or unmodified, in source code or in binary form.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met: redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer;
// redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution;
// neither the name of the copyright holders nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//
// Authors: Gabe Black

let {{

    # Text accumulators; the builder defined below appends generated
    # declarations, constructors and execute methods to these.
    header_output = decoder_output = exec_output = ""

    # Flag snippet appended for the "overflow" flag type: bit 0 of resTemp
    # is moved to bit position 27 and written to CpsrQ.
    calcQCode = '''
        CpsrQ = (resTemp & 1) << 27;
    '''

    # Flag snippet template for every other flag type. negBit selects the
    # bit of resTemp copied into N, zType the width used for the zero test.
    # The doubled percent signs survive the Python formatting step as the
    # single percent signs that the DPRINTF format string needs.
    calcCcCode = '''
        uint16_t _iz, _in;
        _in = (resTemp >> %(negBit)d) & 1;
        _iz = ((%(zType)s)resTemp == 0);

        CondCodesNZ = (_in << 1) | _iz;

        DPRINTF(Arm, "(in, iz) = (%%d, %%d)\\n", _in, _iz);
       '''

    def buildMultInst(mnem, doCc, unCc, regs, code, flagType):
        '''Generate the declaration, constructor and execute text for one
        multiply instruction and append it to the output accumulators.

        mnem     -- base mnemonic. The class name is mnem.capitalize(); the
                    flag-setting variant uses mnem + "s" and the class name
                    with "Cc" appended.
        doCc     -- when true, build the variant whose code is followed by
                    the flag-update snippet.
        unCc     -- when true, build the variant that runs code alone.
        regs     -- number of register operands; 3 selects the Mult3 base
                    and templates, 4 selects Mult4.
        code     -- C++ snippet that computes the result and sets resTemp.
        flagType -- "overflow" appends calcQCode; anything else appends
                    calcCcCode, with "llbit" testing bit 63 and a 64-bit
                    zero check, and other values bit 31 and a 32-bit one.

        Raises Exception when regs is neither 3 nor 4.
        '''
        global header_output, decoder_output, exec_output

        # NOTE(review): neither value is used below. The lookups are kept
        # because they presumably fail for a flagType missing from the
        # shared tables defined elsewhere; confirm before removing them.
        cCode = carryCode[flagType]
        vCode = overflowCode[flagType]

        if flagType == "overflow":
            ccCode = calcQCode
        else:
            if flagType == "llbit":
                zType, negBit = "uint64_t", 63
            else:
                zType, negBit = "uint32_t", 31
            ccCode = calcCcCode % {
                "negBit": negBit,
                "zType": zType
            }

        if regs not in (3, 4):
            # The register count is substituted into the message here; the
            # placeholder was previously left unformatted. The call form of
            # raise parses on both Python 2 and Python 3, and the s
            # conversion keeps the report usable for a non-integer regs.
            raise Exception("Multiplication instructions with %s "
                            "registers are not implemented" % regs)

        if regs == 3:
            base = 'Mult3'
            declare = Mult3Declare
            constructor = Mult3Constructor
        else:
            base = 'Mult4'
            declare = Mult4Declare
            constructor = Mult4Constructor

        Name = mnem.capitalize()

        # Collect the requested variants, plain one first, so the generated
        # text keeps a fixed order.
        variants = []
        if unCc:
            variants.append(InstObjParams(mnem, Name, base,
                            {"code" : code,
                             "predicate_test": pickPredicate(code),
                             "op_class": "IntMultOp" }))
        if doCc:
            variants.append(InstObjParams(mnem + "s", Name + "Cc", base,
                              {"code" : code + ccCode,
                               "predicate_test": pickPredicate(code + ccCode),
                               "op_class": "IntMultOp" }))

        for iop in variants:
            header_output += declare.subst(iop)
            decoder_output += constructor.subst(iop)
            exec_output += PredOpExecute.subst(iop)

    def buildMult3Inst(mnem, code, flagType = "logic"):
        '''Build a three-register multiply in both the plain and the
        flag-setting form.'''
        buildMultInst(mnem, doCc=True, unCc=True, regs=3, code=code,
                      flagType=flagType)

    def buildMult3InstCc(mnem, code, flagType = "logic"):
        '''Build only the flag-setting form of a three-register multiply.'''
        buildMultInst(mnem, doCc=True, unCc=False, regs=3, code=code,
                      flagType=flagType)

    def buildMult3InstUnCc(mnem, code, flagType = "logic"):
        '''Build only the plain form of a three-register multiply; no
        flag-update code is appended.'''
        buildMultInst(mnem, doCc=False, unCc=True, regs=3, code=code,
                      flagType=flagType)

    def buildMult4Inst(mnem, code, flagType = "logic"):
        '''Build a four-register multiply in both the plain and the
        flag-setting form.'''
        buildMultInst(mnem, doCc=True, unCc=True, regs=4, code=code,
                      flagType=flagType)

    def buildMult4InstCc(mnem, code, flagType = "logic"):
        '''Build only the flag-setting form of a four-register multiply.'''
        buildMultInst(mnem, doCc=True, unCc=False, regs=4, code=code,
                      flagType=flagType)

    def buildMult4InstUnCc(mnem, code, flagType = "logic"):
        '''Build only the plain form of a four-register multiply; no
        flag-update code is appended.'''
        buildMultInst(mnem, doCc=False, unCc=True, regs=4, code=code,
                      flagType=flagType)

    # 32-bit multiplies. mla and mul are built in both the plain and the
    # flag-setting form; mls is built in the plain form only. The default
    # "logic" flag type makes the flag-setting forms take N from bit 31 of
    # resTemp and Z from a 32-bit zero test.
    buildMult4Inst    ("mla", "Reg0 = resTemp = Reg1 * Reg2 + Reg3;")
    buildMult4InstUnCc("mls", "Reg0 = resTemp = Reg3 - Reg1 * Reg2;")
    buildMult3Inst    ("mul", "Reg0 = resTemp = Reg1 * Reg2;")
    # Halfword multiply-accumulate, flag-setting form only. Each snippet
    # sign-extends 16-bit halves of Reg1 and Reg2 (bits 15:0 or 31:16),
    # multiplies them and adds Reg3. resTemp is then reduced to one bit
    # that is 1 when bits 32 and 31 of the sum differ; the "overflow" flag
    # type appends the snippet that writes that bit to CpsrQ.

    # Low half of Reg1 times low half of Reg2.
    buildMult4InstCc  ("smlabb", '''Reg0 = resTemp =
                                        sext<16>(bits(Reg1, 15, 0)) *
                                        sext<16>(bits(Reg2_sw, 15, 0)) +
                                        Reg3_sw;
                                 resTemp = bits(resTemp, 32) !=
                                           bits(resTemp, 31);
                                 ''', "overflow")
    # Low half of Reg1 times high half of Reg2.
    buildMult4InstCc  ("smlabt", '''Reg0 = resTemp =
                                        sext<16>(bits(Reg1, 15, 0)) *
                                        sext<16>(bits(Reg2_sw, 31, 16)) +
                                        Reg3_sw;
                                 resTemp = bits(resTemp, 32) !=
                                           bits(resTemp, 31);
                                 ''', "overflow")
    # High half of Reg1 times low half of Reg2.
    buildMult4InstCc  ("smlatb", '''Reg0 = resTemp =
                                        sext<16>(bits(Reg1, 31, 16)) *
                                        sext<16>(bits(Reg2_sw, 15, 0)) +
                                        Reg3_sw;
                                 resTemp = bits(resTemp, 32) !=
                                           bits(resTemp, 31);
                                 ''', "overflow")
    # High half of Reg1 times high half of Reg2.
    buildMult4InstCc  ("smlatt", '''Reg0 = resTemp =
                                        sext<16>(bits(Reg1, 31, 16)) *
                                        sext<16>(bits(Reg2_sw, 31, 16)) +
                                        Reg3_sw;
                                 resTemp = bits(resTemp, 32) !=
                                           bits(resTemp, 31);
                                 ''', "overflow")
    # Dual form: high times high plus low times low, plus Reg3.
    buildMult4InstCc  ("smlad", '''Reg0 = resTemp =
                                        sext<16>(bits(Reg1, 31, 16)) *
                                        sext<16>(bits(Reg2, 31, 16)) +
                                        sext<16>(bits(Reg1, 15, 0)) *
                                        sext<16>(bits(Reg2, 15, 0)) +
                                        Reg3_sw;
                                    resTemp = bits(resTemp, 32) !=
                                              bits(resTemp, 31);
                                ''', "overflow")
    # Dual form with the halves of Reg2 exchanged: high of Reg1 times low
    # of Reg2 plus low of Reg1 times high of Reg2, plus Reg3.
    buildMult4InstCc  ("smladx", '''Reg0 = resTemp =
                                         sext<16>(bits(Reg1, 31, 16)) *
                                         sext<16>(bits(Reg2, 15, 0)) +
                                         sext<16>(bits(Reg1, 15, 0)) *
                                         sext<16>(bits(Reg2, 31, 16)) +
                                         Reg3_sw;
                                    resTemp = bits(resTemp, 32) !=
                                              bits(resTemp, 31);
                                 ''', "overflow")
    # 64-bit accumulate. Reg1 (upper word) and Reg0 (lower word) are read
    # as one 64-bit value, the product or products of Reg2 and Reg3 are
    # added, and the sum is written back split across the same pair.

    # Full 32x32 product, plain and flag-setting forms. "llbit" makes the
    # flag-setting form take N from bit 63 and Z from a 64-bit zero test.
    buildMult4Inst    ("smlal", '''resTemp = sext<32>(Reg2) * sext<32>(Reg3) +
                                       (int64_t)((Reg1_ud << 32) | Reg0_ud);
                                   Reg0_ud = (uint32_t)resTemp;
                                   Reg1_ud = (uint32_t)(resTemp >> 32);
                                ''', "llbit")
    # Halfword products, plain form only. The two letters after "smlal"
    # match the halves used: b reads bits 15:0 and t reads bits 31:16,
    # first of Reg2 and then of Reg3.
    buildMult4InstUnCc("smlalbb", '''resTemp = sext<16>(bits(Reg2, 15, 0)) *
                                               sext<16>(bits(Reg3, 15, 0)) +
                                               (int64_t)((Reg1_ud << 32) |
                                                         Reg0_ud);
                                     Reg0_ud = (uint32_t)resTemp;
                                     Reg1_ud = (uint32_t)(resTemp >> 32);
                                  ''')
    buildMult4InstUnCc("smlalbt", '''resTemp = sext<16>(bits(Reg2, 15, 0)) *
                                               sext<16>(bits(Reg3, 31, 16)) +
                                               (int64_t)((Reg1_ud << 32) |
                                                         Reg0_ud);
                                     Reg0_ud = (uint32_t)resTemp;
                                     Reg1_ud = (uint32_t)(resTemp >> 32);
                                  ''')
    buildMult4InstUnCc("smlaltb", '''resTemp = sext<16>(bits(Reg2, 31, 16)) *
                                               sext<16>(bits(Reg3, 15, 0)) +
                                               (int64_t)((Reg1_ud << 32) |
                                                         Reg0_ud);
                                     Reg0_ud = (uint32_t)resTemp;
                                     Reg1_ud = (uint32_t)(resTemp >> 32);
                                  ''')
    buildMult4InstUnCc("smlaltt", '''resTemp = sext<16>(bits(Reg2, 31, 16)) *
                                               sext<16>(bits(Reg3, 31, 16)) +
                                               (int64_t)((Reg1_ud << 32) |
                                                         Reg0_ud);
                                     Reg0_ud = (uint32_t)resTemp;
                                     Reg1_ud = (uint32_t)(resTemp >> 32);
                                  ''')
    # Dual form: high times high plus low times low, added to the pair.
    buildMult4InstUnCc("smlald", '''resTemp =
                                        sext<16>(bits(Reg2, 31, 16)) *
                                        sext<16>(bits(Reg3, 31, 16)) +
                                        sext<16>(bits(Reg2, 15, 0)) *
                                        sext<16>(bits(Reg3, 15, 0)) +
                                        (int64_t)((Reg1_ud << 32) |
                                                  Reg0_ud);
                                    Reg0_ud = (uint32_t)resTemp;
                                    Reg1_ud = (uint32_t)(resTemp >> 32);
                                 ''')
    # Dual form with the halves of Reg3 exchanged.
    buildMult4InstUnCc("smlaldx", '''resTemp =
                                         sext<16>(bits(Reg2, 31, 16)) *
                                         sext<16>(bits(Reg3, 15, 0)) +
                                         sext<16>(bits(Reg2, 15, 0)) *
                                         sext<16>(bits(Reg3, 31, 16)) +
                                         (int64_t)((Reg1_ud << 32) |
                                                   Reg0_ud);
                                     Reg0_ud = (uint32_t)resTemp;
                                     Reg1_ud = (uint32_t)(resTemp >> 32);
                                  ''')
    # Word-by-halfword multiply-accumulate, flag-setting form only. Reg1
    # is multiplied by a sign-extended half of Reg2, Reg3 shifted left by
    # 16 is added, and the total is shifted right by 16. resTemp is then
    # reduced to the bit-32 versus bit-31 comparison used by the
    # "overflow" flag type.
    buildMult4InstCc  ("smlawb", '''Reg0 = resTemp =
                                        (Reg1_sw *
                                         sext<16>(bits(Reg2, 15, 0)) +
                                         ((int64_t)Reg3_sw << 16)) >> 16;
                                    resTemp = bits(resTemp, 32) !=
                                              bits(resTemp, 31);
                                 ''', "overflow")
    buildMult4InstCc  ("smlawt", '''Reg0 = resTemp =
                                        (Reg1_sw *
                                         sext<16>(bits(Reg2, 31, 16)) +
                                         ((int64_t)Reg3_sw << 16)) >> 16;
                                    resTemp = bits(resTemp, 32) !=
                                              bits(resTemp, 31);
                                 ''', "overflow")
    # Dual multiply-subtract with accumulate, flag-setting form only: the
    # low-half product minus the high-half product, plus Reg3.
    buildMult4InstCc  ("smlsd", '''Reg0 = resTemp =
                                       sext<16>(bits(Reg1, 15, 0)) *
                                       sext<16>(bits(Reg2, 15, 0)) -
                                       sext<16>(bits(Reg1, 31, 16)) *
                                       sext<16>(bits(Reg2, 31, 16)) +
                                       Reg3_sw;
                                    resTemp = bits(resTemp, 32) !=
                                              bits(resTemp, 31);
                                ''', "overflow")
    # Same with the halves of Reg2 exchanged.
    buildMult4InstCc  ("smlsdx", '''Reg0 = resTemp =
                                        sext<16>(bits(Reg1, 15, 0)) *
                                        sext<16>(bits(Reg2, 31, 16)) -
                                        sext<16>(bits(Reg1, 31, 16)) *
                                        sext<16>(bits(Reg2, 15, 0)) +
                                        Reg3_sw;
                                    resTemp = bits(resTemp, 32) !=
                                              bits(resTemp, 31);
                                 ''', "overflow")
    # Dual multiply-subtract with 64-bit accumulate, plain form only. The
    # low-half product of Reg2 and Reg3 minus their high-half product is
    # added to the 64-bit value held in Reg1 (upper word) and Reg0 (lower
    # word), and the sum is written back to that pair.
    buildMult4InstUnCc("smlsld", '''resTemp =
                                        sext<16>(bits(Reg2, 15, 0)) *
                                        sext<16>(bits(Reg3, 15, 0)) -
                                        sext<16>(bits(Reg2, 31, 16)) *
                                        sext<16>(bits(Reg3, 31, 16)) +
                                        (int64_t)((Reg1_ud << 32) |
                                                  Reg0_ud);
                                    Reg0_ud = (uint32_t)resTemp;
                                    Reg1_ud = (uint32_t)(resTemp >> 32);
                                 ''')
    # Same with the halves of Reg3 exchanged.
    buildMult4InstUnCc("smlsldx", '''resTemp =
                                         sext<16>(bits(Reg2, 15, 0)) *
                                         sext<16>(bits(Reg3, 31, 16)) -
                                         sext<16>(bits(Reg2, 31, 16)) *
                                         sext<16>(bits(Reg3, 15, 0)) +
                                         (int64_t)((Reg1_ud << 32) |
                                                   Reg0_ud);
                                     Reg0_ud = (uint32_t)resTemp;
                                     Reg1_ud = (uint32_t)(resTemp >> 32);
                                  ''')
    # Upper-word multiplies, plain form only. The 64-bit signed product of
    # Reg1 and Reg2 is combined with Reg3 shifted left by 32 where an
    # accumulator is present, and the value shifted right by 32 goes to
    # Reg0. The forms ending in "r" add 0x80000000 before the shift
    # (presumably to round rather than truncate; confirm against the
    # architecture manual).

    # Accumulate: Reg3 shifted left by 32 plus the product.
    buildMult4InstUnCc("smmla", '''Reg0 = resTemp =
                                       ((int64_t)(Reg3_ud << 32) +
                                        (int64_t)Reg1_sw *
                                        (int64_t)Reg2_sw) >> 32;
                                ''')
    buildMult4InstUnCc("smmlar", '''Reg0 = resTemp =
                                        ((int64_t)(Reg3_ud << 32) +
                                         (int64_t)Reg1_sw *
                                         (int64_t)Reg2_sw +
                                         ULL(0x80000000)) >> 32;
                                 ''')
    # Subtract: Reg3 shifted left by 32 minus the product.
    buildMult4InstUnCc("smmls", '''Reg0 = resTemp =
                                       ((int64_t)(Reg3_ud << 32) -
                                        (int64_t)Reg1_sw *
                                        (int64_t)Reg2_sw) >> 32;
                                ''')
    buildMult4InstUnCc("smmlsr", '''Reg0 = resTemp =
                                        ((int64_t)(Reg3_ud << 32) -
                                         (int64_t)Reg1_sw *
                                         (int64_t)Reg2_sw +
                                         ULL(0x80000000)) >> 32;
                                 ''')
    # No accumulator: the product alone.
    buildMult3InstUnCc("smmul", '''Reg0 = resTemp =
                                       ((int64_t)Reg1_sw *
                                        (int64_t)Reg2_sw) >> 32;
                                ''')
    buildMult3InstUnCc("smmulr", '''Reg0 = resTemp =
                                        ((int64_t)Reg1_sw *
                                         (int64_t)Reg2_sw +
                                         ULL(0x80000000)) >> 32;
                                 ''')
    # Dual multiply-add without accumulator, flag-setting form only: the
    # sum of two halfword products goes to Reg0, and resTemp is reduced to
    # the bit-32 versus bit-31 comparison used by the "overflow" flag type.
    buildMult3InstCc  ("smuad", '''Reg0 = resTemp =
                                        sext<16>(bits(Reg1, 15, 0)) *
                                        sext<16>(bits(Reg2, 15, 0)) +
                                        sext<16>(bits(Reg1, 31, 16)) *
                                        sext<16>(bits(Reg2, 31, 16));
                                    resTemp = bits(resTemp, 32) !=
                                              bits(resTemp, 31);
                                ''', "overflow")
    # Same with the halves of Reg2 exchanged.
    buildMult3InstCc  ("smuadx", '''Reg0 = resTemp =
                                        sext<16>(bits(Reg1, 15, 0)) *
                                        sext<16>(bits(Reg2, 31, 16)) +
                                        sext<16>(bits(Reg1, 31, 16)) *
                                        sext<16>(bits(Reg2, 15, 0));
                                    resTemp = bits(resTemp, 32) !=
                                              bits(resTemp, 31);
                                 ''', "overflow")
    # Single halfword products, plain form only. The two letters after
    # "smul" match the halves used: b reads bits 15:0 and t reads bits
    # 31:16, first of Reg1 and then of Reg2.
    buildMult3InstUnCc("smulbb", '''Reg0 = resTemp =
                                         sext<16>(bits(Reg1, 15, 0)) *
                                         sext<16>(bits(Reg2, 15, 0));
                                 ''')
    buildMult3InstUnCc("smulbt", '''Reg0 = resTemp =
                                         sext<16>(bits(Reg1, 15, 0)) *
                                         sext<16>(bits(Reg2, 31, 16));
                                 ''')
    buildMult3InstUnCc("smultb", '''Reg0 = resTemp =
                                         sext<16>(bits(Reg1, 31, 16)) *
                                         sext<16>(bits(Reg2, 15, 0));
                                 ''')
    buildMult3InstUnCc("smultt", '''Reg0 = resTemp =
                                         sext<16>(bits(Reg1, 31, 16)) *
                                         sext<16>(bits(Reg2, 31, 16));
                                 ''')
    # Signed 32x32 product written to Reg1 (upper word) and Reg0 (lower
    # word), plain and flag-setting forms. "llbit" makes the flag-setting
    # form take N from bit 63 and Z from a 64-bit zero test.
    buildMult4Inst    ("smull", '''resTemp = (int64_t)Reg2_sw *
                                             (int64_t)Reg3_sw;
                                   Reg1 = (int32_t)(resTemp >> 32);
                                   Reg0 = (int32_t)resTemp;
                                ''', "llbit")
    # Word-by-halfword products shifted right by 16, plain form only.
    buildMult3InstUnCc("smulwb", '''Reg0 = resTemp =
                                        (Reg1_sw *
                                         sext<16>(bits(Reg2, 15, 0))) >> 16;
                                 ''')
    buildMult3InstUnCc("smulwt", '''Reg0 = resTemp =
                                        (Reg1_sw *
                                         sext<16>(bits(Reg2, 31, 16))) >> 16;
                                 ''')
    # Dual multiply-subtract without accumulator, plain form only: the
    # low-half product minus the high-half product.
    buildMult3InstUnCc("smusd", '''Reg0 = resTemp =
                                        sext<16>(bits(Reg1, 15, 0)) *
                                        sext<16>(bits(Reg2, 15, 0)) -
                                        sext<16>(bits(Reg1, 31, 16)) *
                                        sext<16>(bits(Reg2, 31, 16));
                                ''')
    # Same with the halves of Reg2 exchanged.
    buildMult3InstUnCc("smusdx", '''Reg0 = resTemp =
                                        sext<16>(bits(Reg1, 15, 0)) *
                                        sext<16>(bits(Reg2, 31, 16)) -
                                        sext<16>(bits(Reg1, 31, 16)) *
                                        sext<16>(bits(Reg2, 15, 0));
                                 ''')
    # Unsigned long multiplies. Each writes the upper word of the 64-bit
    # result to Reg1 and the lower word to Reg0.

    # Product plus Reg0 plus Reg1, each added as a separate value; plain
    # form only.
    buildMult4InstUnCc("umaal", '''resTemp = Reg2_ud * Reg3_ud +
                                             Reg0_ud + Reg1_ud;
                                   Reg1_ud = (uint32_t)(resTemp >> 32);
                                   Reg0_ud = (uint32_t)resTemp;
                                ''')
    # Product plus the 64-bit value formed from Reg1 (upper word) and Reg0
    # (lower word); plain and flag-setting forms, with the flags taken
    # from all 64 bits.
    buildMult4Inst    ("umlal", '''resTemp = Reg2_ud * Reg3_ud + Reg0_ud +
                                             (Reg1_ud << 32);
                                   Reg1_ud = (uint32_t)(resTemp >> 32);
                                   Reg0_ud = (uint32_t)resTemp;
                                ''', "llbit")
    # Product alone; plain and flag-setting forms, with the flags taken
    # from all 64 bits.
    buildMult4Inst    ("umull", '''resTemp = Reg2_ud * Reg3_ud;
                                   Reg1 = (uint32_t)(resTemp >> 32);
                                   Reg0 = (uint32_t)resTemp;
                                ''', "llbit")
}};