// Copyright (c) 2006-2007 The Regents of The University of Michigan
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met: redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer;
// redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution;
// neither the name of the copyright holders nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//
// Authors: Gabe Black
//          Steve Reinhardt

////////////////////////////////////////////////////////////////////
//
// Branch instructions
//

output header {{
        /**
         * Common base for the branch instruction classes in this file.
         * It adds nothing to SparcStaticInst beyond its own disassembly
         * routine.
         */
        class Branch : public SparcStaticInst
        {
          protected:
            // Passes the mnemonic, machine instruction and op class
            // straight through to the SparcStaticInst constructor.
            Branch(const char *mnem, ExtMachInst _machInst,
                    OpClass __opClass)
                : SparcStaticInst(mnem, _machInst, __opClass)
            {}

            // Produces the textual form of the instruction.
            std::string generateDisassembly(
                    Addr pc, const SymbolTable *symtab) const;
        };

        /**
         * Intermediate base for branches that carry an immediate
         * displacement in the disp member.
         */
        class BranchDisp : public Branch
        {
          protected:
            // This constructor does not assign disp; the derived classes
            // in this file set it in their own constructor bodies.
            BranchDisp(const char *mnem, ExtMachInst _machInst,
                    OpClass __opClass)
                : Branch(mnem, _machInst, __opClass)
            {}

            // Produces the textual form of the instruction.
            std::string generateDisassembly(
                    Addr pc, const SymbolTable *symtab) const;

            // Signed displacement of the branch.
            int32_t disp;
        };

        /**
         * Displacement branch whose offset is taken from the low `bits`
         * bits of the machine instruction.
         */
        template<int bits>
        class BranchNBits : public BranchDisp
        {
          protected:
            BranchNBits(const char *mnem, ExtMachInst _machInst,
                    OpClass __opClass)
                : BranchDisp(mnem, _machInst, __opClass)
            {
                // Isolate the field and shift it left by two; the result
                // is then sign extended from bits + 2 bits.
                const uint64_t scaled = (_machInst & mask(bits)) << 2;
                disp = sext<bits + 2>(scaled);
            }
        };

        /**
         * Displacement branch whose offset is split across the D16HI and
         * D16LO instruction fields.
         */
        class BranchSplit : public BranchDisp
        {
          protected:
            BranchSplit(const char *mnem, ExtMachInst _machInst,
                    OpClass __opClass)
                : BranchDisp(mnem, _machInst, __opClass)
            {
                // D16LO is placed from bit 2 upward and D16HI from bit 16
                // upward; the combined value is sign extended from 18 bits.
                disp = sext<18>((D16LO << 2) | (D16HI << 16));
            }
        };

        /**
         * Base class for branches whose displacement is computed from a
         * register and an immediate; the immediate is kept in imm.
         */
        class BranchImm13 : public Branch
        {
          protected:
            // Stores the SIMM13 field, sign extended from 13 bits, in imm.
            BranchImm13(const char *mnem, ExtMachInst _machInst,
                    OpClass __opClass)
                : Branch(mnem, _machInst, __opClass),
                  imm(sext<13>(SIMM13))
            {}

            // Produces the textual form of the instruction.
            std::string generateDisassembly(
                    Addr pc, const SymbolTable *symtab) const;

            // Sign-extended immediate operand.
            int32_t imm;
        };
}};

output decoder {{

        // Explicit instantiations of BranchNBits for the 19-, 22- and
        // 30-bit displacement widths, emitted with the decoder output.
        template class BranchNBits<19>;

        template class BranchNBits<22>;

        template class BranchNBits<30>;

        // Disassembles as: mnemonic, source registers, destination
        // register 0, with a ", " between the two register groups when
        // both are present. Neither pc nor symtab is consulted.
        std::string
        Branch::generateDisassembly(Addr pc, const SymbolTable *symtab) const
        {
            std::stringstream out;

            printMnemonic(out, mnemonic);
            printRegArray(out, _srcRegIdx, _numSrcRegs);

            const bool needSeparator = _numSrcRegs && _numDestRegs;
            if (needSeparator)
                out << ", ";

            printDestReg(out, 0);
            return out.str();
        }

        // Disassembles as: mnemonic, source registers, the immediate in
        // hex, then destination register 0. A ", " follows the sources
        // when there are any, and precedes the destination when there is
        // one.
        std::string
        BranchImm13::generateDisassembly(Addr pc,
                const SymbolTable *symtab) const
        {
            std::stringstream out;
            const bool hasSrc = _numSrcRegs > 0;
            const bool hasDest = _numDestRegs > 0;

            printMnemonic(out, mnemonic);
            printRegArray(out, _srcRegIdx, _numSrcRegs);
            if (hasSrc)
                out << ", ";

            ccprintf(out, "0x%x", imm);
            if (hasDest)
                out << ", ";

            printDestReg(out, 0);
            return out.str();
        }

        // Disassembles as: mnemonic followed by the target address
        // (pc + disp) in hex. When a symbol table is supplied and it finds
        // a nearest symbol, " <symbol>" or " <symbol+offset>" is appended.
        std::string
        BranchDisp::generateDisassembly(Addr pc,
                const SymbolTable *symtab) const
        {
            Addr target = pc + disp;

            std::stringstream out;
            printMnemonic(out, mnemonic);
            ccprintf(out, "0x%x", target);

            // Without a symbol table the bare address is all we can show.
            if (!symtab)
                return out.str();

            std::string symbol;
            Addr symbolAddr;
            if (!symtab->findNearestSymbol(target, symbol, symbolAddr))
                return out.str();

            ccprintf(out, " <%s", symbol);
            if (symbolAddr == target)
                ccprintf(out, ">");
            else
                ccprintf(out, "+%d>", target - symbolAddr);

            return out.str();
        }
}};

// Template for the execute() method of instructions produced by the
// Branch format. The operand declaration and read snippets (op_decl,
// op_rd) are substituted first, the instruction's code snippet runs
// unconditionally, and the operand write-back snippet (op_wb) is only
// executed while fault is still NoFault.
def template JumpExecute {{
        Fault %(class_name)s::execute(%(CPU_exec_context)s *xc,
                Trace::InstRecord *traceData) const
        {
            // Attempt to execute the instruction
            Fault fault = NoFault;

            %(op_decl)s;
            %(op_rd)s;

            %(code)s;

            if (fault == NoFault) {
                // Write the resulting state to the execution context
                %(op_wb)s;
            }

            return fault;
        }
}};

// Template for the execute() method of instructions produced by
// doBranch(). The code snippet runs when the cond expression is true and
// the fail snippet runs otherwise; the operand write-back snippet (op_wb)
// is only executed while fault is still NoFault.
def template BranchExecute {{
        Fault
        %(class_name)s::execute(%(CPU_exec_context)s *xc,
                Trace::InstRecord *traceData) const
        {
            // Attempt to execute the instruction
            Fault fault = NoFault;

            %(op_decl)s;
            %(op_rd)s;

            if (%(cond)s) {
                %(code)s;
            } else {
                %(fail)s;
            }

            if (fault == NoFault) {
                // Write the resulting state to the execution context
                %(op_wb)s;
            }

            return fault;
        }
}};

// Decode snippet used by doBranch() when an annulling variant was
// generated: if A is true it returns the "<class_name>Annul" instruction
// with ",a" appended to the mnemonic, otherwise the plain instruction.
// NOTE(review): A is presumably the annul bitfield of the instruction; its
// definition is not visible in this file.
def template BranchDecode {{
    if (A)
        return new %(class_name)sAnnul("%(mnemonic)s,a", machInst);
    else
        return new %(class_name)s("%(mnemonic)s", machInst);
}};

// Primary format for branch instructions:
def format Branch(code, *opt_flags) {{
    code = 'NNPC = NNPC;\n' + code
    (usesImm, code, immCode,
     rString, iString) = splitOutImm(code)
    iop = InstObjParams(name, Name, 'Branch', code, opt_flags)
    header_output = BasicDeclare.subst(iop)
    decoder_output = BasicConstructor.subst(iop)
    exec_output = JumpExecute.subst(iop)
    if usesImm:
        imm_iop = InstObjParams(name, Name + 'Imm', 'BranchImm' + iString,
                immCode, opt_flags)
        header_output += BasicDeclare.subst(imm_iop)
        decoder_output += BasicConstructor.subst(imm_iop)
        exec_output += JumpExecute.subst(imm_iop)
        decode_block = ROrImmDecode.subst(iop)
    else:
        decode_block = BasicDecode.subst(iop)
}};

let {{
    def doBranch(name, Name, base, cond,
            code, annul_code, fail, annul_fail, opt_flags):
        # Generates the declaration, constructor, execute and decode text
        # for a branch instruction derived from `base`. When annul_code is
        # anything other than the string "None", a second "<Name>Annul"
        # class with mnemonic "<name>,a" is generated as well and the
        # decode block selects between the two.
        #
        # Returns (header_output, decoder_output, exec_output, decode_block).
        #
        #@todo: add flags and branchTarget() for DirectCntrl branches
        #       the o3 model can take advantage of this annotation if
        #       done correctly
        has_annul = annul_code != "None"

        iop = InstObjParams(name, Name, base,
                {"code": code,
                 "fail": fail,
                 "cond": cond
                },
                opt_flags)
        header_output = BasicDeclareWithMnemonic.subst(iop)
        decoder_output = BasicConstructorWithMnemonic.subst(iop)
        exec_output = BranchExecute.subst(iop)

        if not has_annul:
            # Only the plain instruction exists, so decode it directly.
            decode_block = BasicDecodeWithMnemonic.subst(iop)
            return (header_output, decoder_output, exec_output, decode_block)

        decode_block = BranchDecode.subst(iop)

        # Annulling variant: same base and condition, different snippets.
        annul_iop = InstObjParams(name + ',a', Name + 'Annul', base,
                {"code": annul_code,
                 "fail": annul_fail,
                 "cond": cond
                },
                opt_flags)
        header_output += BasicDeclareWithMnemonic.subst(annul_iop)
        decoder_output += BasicConstructorWithMnemonic.subst(annul_iop)
        exec_output += BranchExecute.subst(annul_iop)
        return (header_output, decoder_output, exec_output, decode_block)

    def doCondBranch(name, Name, base, cond, code, opt_flags):
        # Conditional branch: the same code snippet serves the plain and
        # the annulling variant; only the not-taken snippets differ.
        plain_fail = 'NNPC = NNPC; NPC = NPC;\n'
        annul_fail = 'NNPC = NPC + 8; NPC = NPC + 4;\n'
        opt_flags += ('IsCondControl', )
        return doBranch(name, Name, base, cond, code, code,
                plain_fail, annul_fail, opt_flags)

    def doUncondBranch(name, Name, base, code, annul_code, opt_flags):
        # Unconditional branch: the condition is the constant "true" and
        # both fail snippets are empty statements.
        no_op = ";"
        opt_flags += ('IsUncondControl', )
        return doBranch(name, Name, base, "true", code, annul_code,
                no_op, no_op, opt_flags)

    # Default branch snippet: sets NNPC to PC + disp. The BranchN and
    # BranchSplit formats substitute it when their code argument is the
    # string "default_branch_code".
    default_branch_code = 'NNPC = PC + disp;\n'
}};

// Format for branch instructions with n bit displacements:
def format BranchN(bits, code=default_branch_code,
        test=None, annul_code=None, *opt_flags) {{
    if code == "default_branch_code":
        code = default_branch_code
    if test != "None":
        (header_output,
         decoder_output,
         exec_output,
         decode_block) = doCondBranch(name, Name,
             "BranchNBits<%d>" % bits, test, code, opt_flags)
    else:
        (header_output,
         decoder_output,
         exec_output,
         decode_block) = doUncondBranch(name, Name,
             "BranchNBits<%d>" % bits, code, annul_code, opt_flags)
}};

// Format for branch instructions with split displacements:
def format BranchSplit(code=default_branch_code,
        test=None, annul_code=None, *opt_flags) {{
    if code == "default_branch_code":
        code = default_branch_code
    if test != "None":
        (header_output,
         decoder_output,
         exec_output,
         decode_block) = doCondBranch(name, Name,
             "BranchSplit", test, code, opt_flags)
    else:
        (header_output,
         decoder_output,
         exec_output,
         decode_block) = doUncondBranch(name, Name,
             "BranchSplit", code, annul_code, opt_flags)
}};