#!/usr/bin/env python3

"""Copyright (C) 2024 Advanced Micro Devices, Inc. All rights reserved.

   Permission is hereby granted, free of charge, to any person obtaining a copy
   of this software and associated documentation files (the "Software"), to deal
   in the Software without restriction, including without limitation the rights
   to use, copy, modify, merge, publish, distribute, sublicense, and/or sell cop-
   ies of the Software, and to permit persons to whom the Software is furnished
   to do so, subject to the following conditions:

   The above copyright notice and this permission notice shall be included in all
   copies or substantial portions of the Software.

   THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IM-
   PLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
   FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
   COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
   IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNE-
   CTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
"""

import sys
import pathlib
import os
import argparse
import logging
import bench
from generator import SuiteProblemGenerator
from pathlib import Path
from git_info import create_github_file
from specs import get_machine_specs

# Console log handler; its level and formatter are set, and it is attached to
# the root logger, in the script entry block at the bottom of this file.
console = logging.StreamHandler()

# Result fields extracted, in this order, into every csv row of a 'matmul' run
matmul_paramList = [
    'transA', 'transB', 'grouped_gemm', 'batch_count',
    'm', 'n', 'k',
    'alpha', 'lda', 'stride_a',
    'beta', 'ldb', 'stride_b',
    'ldc', 'stride_c',
    'ldd', 'stride_d',
    'a_type', 'b_type', 'c_type', 'd_type', 'compute_type',
    'scaleA', 'scaleB', 'scaleC', 'scaleD', 'amaxD',
    'activation_type', 'bias_vector', 'bias_type',
    'rotating_buffer', 'gflops', 'us',
]

# Result fields extracted, in this order, into every csv row of an 'amax' run
amax_paramList = ['m', 'n', 'type', 'dtype', 'us']

# bench type -> list of result fields to report for that type
benchTrackedParams = {
    'matmul': matmul_paramList,
    'amax': amax_paramList,
}

# bench type -> file name of the benchmark executable for that type
benchExecs = {
    'matmul': 'hipblaslt-bench',
    'amax': 'hipblaslt-bench-extop-amax',
}

# d is a default dict or contains a superset of trackedParams
def extractTrackedParams(d, trackedParam):
    """Return the values of ``d`` for the keys in ``trackedParam``, in order.

    Each key is looked up with ``d[key]``, so a plain dict that lacks one of
    the keys raises ``KeyError``.
    """
    picked = []
    for name in trackedParam:
        picked.append(d[name])
    return picked

def runBenchmark(benchType, args, executable_folder, probYamlFolder, out_csv_File, write_csv_header):
    """Run one benchmark problem and emit its results as csv text.

    Args:
        benchType: key into ``benchExecs`` / ``benchTrackedParams`` selecting
            the executable name and the result fields to report.
        args: problem arguments, passed through unchanged to ``bench.run_bench``.
        executable_folder: folder that contains the bench executables.
        probYamlFolder: passed through unchanged to ``bench.run_bench``.
        out_csv_File: writable file object that receives the csv text, or
            ``None`` to print the csv text to stdout instead.
        write_csv_header: when truthy and ``out_csv_File`` is given, the header
            line is written before the result rows. Ignored when printing to
            stdout, where the header is always printed.

    Raises:
        KeyError: if ``benchType`` is not a known bench type.
    """
    # get the actual exec-path of this bench type
    benchExec = benchExecs[benchType]
    benchCmd = pathlib.Path(os.path.join(executable_folder, benchExec)).resolve()

    csvKeys, benchResults, success = bench.run_bench(benchCmd, probYamlFolder, args, True)
    if not success:
        # Presumably a falsy flag means the bench run failed -- confirm against
        # bench.run_bench. Report it instead of dropping it silently; whatever
        # results were returned are still emitted below.
        logging.warning("bench run of %s did not report success for args: %s", benchCmd, args)

    # NOTE(review): the header comes from csvKeys while the rows come from the
    # tracked params of this bench type; verify the two agree in bench.run_bench.
    header = ','.join(str(key) for key in csvKeys) + '\n'

    # get the actual params we want to track and show in csv of this bench type
    trackedParams = benchTrackedParams[benchType]
    content = ''.join(
        ','.join(str(value) for value in extractTrackedParams(result, trackedParams)) + '\n'
        for result in benchResults
    )

    if out_csv_File is not None:
        if write_csv_header:
            out_csv_File.write(header)
        out_csv_File.write(content)
    else:
        print(header)
        print(content)

def command_perf(arguments, probYaml_foler):
    """Run every problem of the requested suites and store the results.

    For each problem set produced by ``SuiteProblemGenerator`` an output folder
    ``<workspace>/hipBLASLt_PTS_Benchmarks_<benchType>/<tag>`` is created. When
    csv export is enabled the results of the set are written to
    ``<problemset name>_benchmark.csv`` in that folder; otherwise
    ``runBenchmark`` prints them to stdout.

    Args:
        arguments: parsed command line; the attributes ``workspace``,
            ``execFolder``, ``suite``, ``csv``, ``pts`` and ``tag`` are read.
        probYaml_foler: folder of the problem yaml files, handed to
            ``runBenchmark`` unchanged.

    Returns:
        None. If ``workspace``, ``execFolder`` or ``suite`` is not set, a hint
        is printed and nothing is run.
    """
    if not arguments.workspace:
        print("Workspace not set. use -w /path/of/workspace")
        return
    print(f'Output data to {arguments.workspace}')

    if not arguments.execFolder:
        print("execFolder not set. use -e /path/of/execFolder")
        return
    print(f'Will call cpp bench executable from folder {arguments.execFolder}')

    if arguments.suite is None:
        print("test problems not set, use --suite for testing problems")
        return
    generator = SuiteProblemGenerator(arguments.suite)

    out_folder = arguments.workspace
    exec_folder = arguments.execFolder

    # pts output always includes the csv files
    needExportCSV = bool(arguments.csv or arguments.pts)

    logging.info("Start bench.")

    for problemSet in generator.generate_problemSet():
        pSetName = problemSet.get_problemset_name()
        print("\n==================================\n|| Running benchmarks suite {}\n==================================".format(pSetName))
        pTypeName = problemSet.benchType
        subDirectory = os.path.join(out_folder, "hipBLASLt_PTS_Benchmarks_"+pTypeName, arguments.tag)
        Path(subDirectory).mkdir(parents=True, exist_ok=True)

        # one csv file per problem set; None makes runBenchmark print instead
        csv_file = None
        if needExportCSV:
            csv_file = open(os.path.join(subDirectory, pSetName+'_benchmark.csv'), 'w')

        try:
            # only the first time we need to write header
            writeCSVHeader = True
            for p in problemSet.generate_problems():
                runBenchmark(pTypeName, p.args, exec_folder, probYaml_foler, csv_file, writeCSVHeader)
                writeCSVHeader = False
        finally:
            # close the file even when a benchmark raises, so the handle is
            # not leaked and already written rows are flushed to disk
            if csv_file is not None:
                csv_file.close()

        if csv_file is not None:
            print("\nResults written to {}".format(csv_file.name))

        # check if we need to output other files for pts
        if arguments.pts:
            # Will only be correct if script is run from directory of the git repo associated
            # with the hipblaslt-bench executable
            create_github_file(os.path.join(subDirectory, 'hipBLASLt-commit-hash.txt'))
            get_machine_specs(os.path.join(subDirectory, 'specs.txt'))

    logging.info("Finish bench.")


def main():
    """Parse the command line, run the benchmarks and exit with status 0.

    Default locations are derived from the location of this script: the
    repository root is three levels up, the problem yaml files live in
    ``problems/`` next to the script, and the bench executables are expected
    in ``build/release/clients/staging/`` under the repository root.
    """
    parser = argparse.ArgumentParser(prog='hipblaslt-perf')

    dir_of_this_script = pathlib.Path(os.path.dirname(os.path.realpath(__file__))).resolve()
    dir_of_this_repo = pathlib.Path(os.path.join(dir_of_this_script, '../../../')).resolve()
    probYaml_folder = pathlib.Path(os.path.join(dir_of_this_script, 'problems/')).resolve()
    executable_folder = pathlib.Path(os.path.join(dir_of_this_repo, "build/release/clients/staging/")).resolve()

    parser.add_argument('-w',
                        '--workspace',
                        type=str,
                        help='workspace folder keeping the perf data',
                        default=os.path.join(dir_of_this_repo, "hipBLASLt_benchmark/"))

    # this argument is mainly used in jenkins/pts. For local dev, just use the default
    parser.add_argument('-e',
                        '--execFolder',
                        type=str,
                        help='folder where the cpp bench executables are located',
                        default=executable_folder)

    parser.add_argument('--suite',
                        type=str,
                        help='benchmark suite to run (required, can be repeated)',
                        action='append')

    parser.add_argument('--tag',
                        type=str,
                        help='subfolder under workspace, (optional, useful for comparing two commits)',
                        default='')

    parser.add_argument('--csv',
                        help='dump result to csv files, default is True',
                        action='store_true',
                        default=True)

    # '--csv' is store_true with a True default, so on its own it can never be
    # switched off; '--no-csv' writes False into the same destination.
    parser.add_argument('--no-csv',
                        dest='csv',
                        help='print results to stdout instead of csv files (ignored when --pts is given)',
                        action='store_false')

    parser.add_argument('--pts',
                        help='dump several required files for pts system: csv, commit-hash, spec, default is False',
                        action='store_true',
                        default=False)

    arguments = parser.parse_args()

    command_perf(arguments, probYaml_folder)

    sys.exit(0)

if __name__ == '__main__':
    # Everything from DEBUG upwards goes to the log file. This must run before
    # the console handler is attached: basicConfig does nothing once the root
    # logger already has a handler.
    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)s %(levelname)s: %(message)s',
                        filename='hipblaslt-perf.log')

    # The console only shows warnings and above, in a shorter format.
    console.setFormatter(logging.Formatter('%(levelname)-8s: %(message)s'))
    console.setLevel(logging.WARNING)
    logging.getLogger('').addHandler(console)

    main()
