# -*- coding: utf-8 -*-
"""
Created on Sat Jul  7 13:41:31 2018

@author: david
"""

import math
import os
import re
import time

import numpy as np

from simulate_swiss import simulate_tournament, milvang_win_or_draw_prob

# The full set of 10000 trials on all the playing fields and
# acceleration methods takes about 8 hours on my laptop.
num_trials = 10000
num_rounds = 6
modelled_rating_floor = 700  # NOTE(review): not referenced in the code visible here
# Rating gaps tallied in the mismatch report: 200, 250, ..., 500.
rating_diff_checks = list(range(200, 501, 50))

# Each name selects the input file sample_files/players_<name>.veg.
player_sets = [
  "christmas_2017",
  "linear",
  "perth_open_2018",
  "wa_open_2018",
  "willetton_2017",
]

# Directory the per-field CSV summaries are written into.
output_folder = "summary_stats"

# Each entry names an accel_<style> function in this module.
acceleration_styles = [
  "none", "aus1", "aus1_3", "aus2",
  "permanent1", "permanent2",
  "bonham",
  "bakuish3", "bakuish2", "bakuish1",
]

# Each entry names a gimme_<definition> function in this module.
gimme_definitions = ["85", "round1", "bottom_half", "non_accel"]
num_gimme_definitions = len(gimme_definitions)


# Import the inequality indices.
# Each row holds six comma-separated gimme counts (the lookup key) followed by
# the inequality index value. lookup_ineq_index() builds keys from counts in
# descending order, so rows are presumably stored that way too -- NOTE(review):
# confirm against the file generator.
inequality_index = {}

with open("inequality_ginis.csv") as f:
  for line in f:
    if not line.strip():
      # Tolerate blank lines (e.g. a trailing empty line), which would
      # otherwise fail on the cells[6] lookup with IndexError.
      continue
    # Strip every cell so stray whitespace cannot produce keys that
    # lookup_ineq_index() (which joins bare str(int) values) never matches.
    cells = [c.strip() for c in line.split(",")]
    inequality_index[",".join(cells[0:6])] = float(cells[6])


def lookup_ineq_index(gimmes):
  """Return the inequality index for a collection of per-player gimme counts.

  The counts are ordered from largest to smallest and joined with commas to
  form the key into the module-level ``inequality_index`` table (loaded from
  inequality_ginis.csv). The caller's sequence is left untouched; the
  original version sorted it in place as a side effect.

  Raises KeyError if that combination of counts is not in the table.
  """
  key = ",".join(str(g) for g in sorted(gimmes, reverse=True))
  # Reads the module-level table populated when the script starts.
  return inequality_index[key]


def gimme_85(this_game, num_players):
  """Score a game as a gimme (1) if its expected score exceeds 0.85, else 0.

  num_players is unused here; it is accepted so every gimme_* function can be
  called with the same (game, num_players) arguments.
  """
  return int(this_game["expected_score"] > 0.85)


def gimme_round1(this_game, num_players):
  """Score a game as a gimme (1) if its expected score beats gimme_threshold.

  gimme_threshold is a module-level global read at call time, so it must be
  assigned before this function runs. It is intended to be the lowest
  round-one expected score among the top six seeds. num_players is unused.
  """
  return 1 if this_game["expected_score"] > gimme_threshold else 0


def gimme_bottom_half(this_game, num_players):
  """Score a game as a gimme (1) if the opponent's index is in the bottom half.

  The opponent index is compared with num_players / 2 using true division, so
  in an odd-sized field the middle index counts as top half (returns 0).
  """
  return int(this_game["opponent"] >= num_players / 2)


def gimme_non_accel(this_game, num_players):
  """Score a game as a gimme (1) if the opponent is outside the accelerated group.

  The accelerated group is the first 2*ceil(num_players/4) indices, matching
  the group size the accel_* functions use.
  """
  return int(this_game["opponent"] >= 2 * math.ceil(num_players / 4))


def import_players_from_file(in_file):
  """Read players from a Vega-like ``;``-separated file, reading ELONAT.

  The first line is treated as a header and skipped. Blank lines are ignored.
  For every other line, field 0 is taken as the name and field 9 (ELONAT) as
  the integer rating.

  Returns a list of {"name": str, "rating": int} dicts ordered from highest
  to lowest rating. Players with equal ratings keep their file order, because
  ``sorted`` is stable.
  """
  players = []
  with open(in_file) as f:
    next(f, None)  # skip the header line (no-op on an empty file)
    for line in f:
      if not line.strip():
        # A blank or trailing empty line would otherwise raise IndexError
        # on the field-9 lookup below.
        continue
      cells = line.split(";")
      players.append({"name": cells[0].strip(),
                      "rating": int(cells[9].strip())})
  return sorted(players, key=lambda p: -p["rating"])


def accel_bonham(players, this_round):
  """Return Bonham-style virtual points: one entry per player in ``players``.

  The first 2*ceil(n/4) players (the accelerated group) receive:
    * 1 point while this_round < 3,
    * 0.5 while this_round < 4,
    * 0.5 while this_round < 5, but only if some player outside the group has
      a score equal to this_round - 1,
  and nothing otherwise. Players outside the group always get 0.
  """
  n = len(players)
  cutoff = 2 * math.ceil(n / 4)
  result = [0] * n

  if this_round < 3:
    bonus = 1.
  elif this_round < 4:
    bonus = 0.5
  elif this_round < 5 and any(p["score"] == this_round - 1
                              for p in players[cutoff:]):
    bonus = 0.5
  else:
    return result

  for idx in range(cutoff):
    result[idx] = bonus
  return result


def accel_aus1(players, this_round):
  """Give the first 2*ceil(n/4) players 1 virtual point while this_round < 3.

  Returns one entry per player; everyone else (and everyone once this_round
  reaches 3) gets 0.
  """
  field_size = len(players)
  points = [0] * field_size
  if this_round < 3:
    top_group = 2 * math.ceil(field_size / 4)
    for seed in range(top_group):
      points[seed] = 1.
  return points
  

def accel_aus1_3(players, this_round):
  """Give the first 2*ceil(n/4) players 1 virtual point while this_round < 4.

  Returns one entry per player; all other entries are 0.
  """
  total = len(players)
  accelerated = 2 * math.ceil(total / 4)
  bonus_points = [0] * total
  if this_round < 4:
    for pos in range(accelerated):
      bonus_points[pos] = 1.
  return bonus_points


def accel_aus2(players, this_round):
  """Give the first 2*ceil(n/4) players 2 virtual points while this_round < 3.

  Returns one entry per player; everyone else (and everyone from this_round
  3 onwards) gets 0.
  """
  n = len(players)
  group = 2 * math.ceil(n / 4)
  vp = [0] * n
  if this_round < 3:
    for k in range(group):
      vp[k] = 2.
  return vp


def accel_bakuish3(players, this_round):
  """Baku-like acceleration: 1 point while this_round < 4, then 0.5 while < 5.

  Only the first 2*ceil(n/4) players receive the bonus; the result has one
  entry per player and is all zeros from this_round 5 onwards.
  """
  n_players = len(players)
  n_top = 2 * math.ceil(n_players / 4)
  points = [0] * n_players
  if this_round < 4:
    bonus = 1.
  elif this_round < 5:
    bonus = 0.5
  else:
    return points
  for idx in range(n_top):
    points[idx] = bonus
  return points


def accel_bakuish2(players, this_round):
  """Baku-like acceleration: 1 point while this_round < 3, then 0.5 while < 4.

  Only the first 2*ceil(n/4) players receive the bonus; the result has one
  entry per player and is all zeros from this_round 4 onwards.
  """
  size = len(players)
  leaders = 2 * math.ceil(size / 4)
  awarded = [0] * size
  if this_round < 3:
    amount = 1.
  elif this_round < 4:
    amount = 0.5
  else:
    return awarded
  for slot in range(leaders):
    awarded[slot] = amount
  return awarded


def accel_bakuish1(players, this_round):
  """Baku-like acceleration with one full round: 1 point while this_round < 2,
  then 0.5 while this_round < 3.

  Only the first 2*ceil(n/4) players receive the bonus; the result has one
  entry per player and is all zeros from this_round 3 onwards.

  The thresholds follow the same pattern as accel_bakuish2 (< 3 / < 4) and
  accel_bakuish3 (< 4 / < 5). They were previously < 1 / < 2. Rounds appear
  to be 1-based (accel_bonham treats a score of this_round - 1 as perfect),
  which would make the old 1-point branch unreachable. NOTE(review): confirm
  the round numbering used by simulate_tournament.
  """
  num_players = len(players)
  num_accelerated = 2*math.ceil(num_players/4)
  virtual_points = [0] * num_players
  if this_round < 2:
    for i in range(num_accelerated):
      virtual_points[i] = 1.
  elif this_round < 3:
    for i in range(num_accelerated):
      virtual_points[i] = 0.5
  
  return virtual_points


def accel_permanent1(players, this_round):
  """Give the first 2*ceil(n/4) players 1 virtual point in every round.

  this_round is ignored; it is accepted to match the other accel_* functions.
  """
  size = len(players)
  boosted = [0] * size
  for rank in range(2 * math.ceil(size / 4)):
    boosted[rank] = 1.
  return boosted


def accel_permanent2(players, this_round):
  """Give the first 2*ceil(n/4) players 2 virtual points in every round.

  this_round is ignored; it is accepted to match the other accel_* functions.
  """
  count = len(players)
  extra = [0] * count
  for position in range(2 * math.ceil(count / 4)):
    extra[position] = 2.
  return extra


def accel_none(players, this_round):
  """No acceleration: every player gets 0 virtual points in every round."""
  no_bonus = [0] * len(players)
  return no_bonus


for player_set in player_sets:
  print("Starting {}".format(player_set))
  base_players = import_players_from_file("sample_files/players_{}.veg".format(player_set))
  
  # Make sure the output directory exists before any simulation runs; the
  # open() calls below would otherwise fail only after a full batch of trials.
  os.makedirs(output_folder, exist_ok=True)
  
  num_players = len(base_players)
  # Size of the group that receives virtual points (same formula as accel_*).
  num_accelerated = 2*math.ceil(num_players/4)
  
  # Define the gimme threshold, based on the lowest expected score amongst the
  # top six players in their round one games, assuming all are Black.
  # Player i is modelled as meeting player i + delta in round one.
  # NOTE(review): presumably milvang_win_or_draw_prob returns
  # (P(win), P(win or draw)) for the first-named player, making
  # 1 - 0.5*(p0 + p1) Black's expected score -- confirm in simulate_swiss.
  delta = math.floor(num_players/2)
  min_expected = 1.
  for i in range(6):
    probs = milvang_win_or_draw_prob(base_players[i+delta]["rating"],
                                     base_players[i]["rating"])
    this_expected = 1.0 - 0.5*(probs[0] + probs[1])
    if this_expected < min_expected:
      min_expected = this_expected
  # gimme_round1() reads this module-level global. It was previously never
  # assigned, so the "round1" gimme definition raised NameError.
  gimme_threshold = min_expected
  
  # Resolve the gimme_<definition> functions once instead of per game.
  gimme_fns = [globals()["gimme_{}".format(g)] for g in gimme_definitions]
  
  time_start = time.time()
  
  for accel in acceleration_styles:
    print("Starting trials for {}".format(accel))
    accel_fn = globals()["accel_{}".format(accel)]
    
    # With permanent acceleration, only the accelerated players are eligible
    # to win or finish top 3.
    num_eligible = num_accelerated if "permanent" in accel else num_players
    
    player_scores = [[0]*num_trials for i in range(num_players)]
    player_wins   = [0]*num_players
    player_top3s  = [0]*num_players
    
    # Indexed [player][min(gimmes, 3)][gimme definition]: the number of trials
    # in which the player had 0, 1, 2 or 3+ gimmes -- overall, when winning,
    # and when finishing top 3.
    player_gimmes =       [[[0 for i in range(num_gimme_definitions)] for j in range(4)] for k in range(num_players)]
    player_gimmes_wins  = [[[0 for i in range(num_gimme_definitions)] for j in range(4)] for k in range(num_players)]
    player_gimmes_top3s = [[[0 for i in range(num_gimme_definitions)] for j in range(4)] for k in range(num_players)]
    
    rating_diffs = [[0 for j in rating_diff_checks] for i in range(num_rounds)]
    mean_diffs = [0]*num_rounds
    
    mean_inequalities = [0 for i in range(num_gimme_definitions)]
    
    for i_trial in range(num_trials):
      if i_trial % 1000 == 0:
        print("Starting sim {}".format(i_trial))
      
      players = simulate_tournament(base_players,
                                    acceleration = accel_fn,
                                    trf_file_base_name = "tournament/sim",
                                    num_rounds = num_rounds,
                                    misrate_player = False)
      
      # Winning score, and the score needed for a top-3 finish (the
      # third-highest score, ties included), among eligible players.
      eligible_scores = sorted((p["score"] for p in players[0:num_eligible]),
                               reverse=True)
      score_win = eligible_scores[0]
      score_top3 = eligible_scores[2]
      
      for i in range(num_players):
        player_scores[i][i_trial] = players[i]["score"]
      
      # Leader's progressive score for each round; computed once per trial
      # rather than once per player, definition and round.
      lead_scores = [np.max([p["prog_scores"][j] for p in players])
                     for j in range(num_rounds)]
      
      player_gimmes_count = [[0 for i in range(num_gimme_definitions)] for j in range(num_players)]
      
      for i in range(num_eligible):
        won = players[i]["score"] >= score_win
        top3 = players[i]["score"] >= score_top3
        if won:
          player_wins[i] += 1
        if top3:
          player_top3s[i] += 1
        
        # Count gimmes: games in rounds where the player's progressive score
        # was within half a point of the leader's.
        for (k, gimme_fn) in enumerate(gimme_fns):
          num_gimmes = 0
          for j in range(num_rounds):
            if players[i]["prog_scores"][j] > lead_scores[j] - 0.51:
              num_gimmes += gimme_fn(players[i]["games"][j], num_players)
          
          player_gimmes_count[i][k] = num_gimmes
          
          bucket = min(3, num_gimmes)
          player_gimmes[i][bucket][k] += 1
          # Win/top-3 tallies are kept for every gimme definition. They were
          # previously recorded only for definition 0, leaving the other
          # Won_*/Top3_* output columns permanently zero.
          if won:
            player_gimmes_wins[i][bucket][k] += 1
          if top3:
            player_gimmes_top3s[i][bucket][k] += 1
      
      for k in range(num_gimme_definitions):
        top6_gimmes = [player_gimmes_count[j][k] for j in range(6)]
        mean_inequalities[k] += lookup_ineq_index(top6_gimmes)
      
      # Count games with large rating differences. Every game is visited from
      # both sides, so only the positive (higher-rated player's) difference is
      # counted; equal-rating games contribute 0 to the mean either way.
      for j in range(num_rounds):
        for i in range(num_players):
          r1 = players[i]["rating"]
          r2 = players[players[i]["games"][j]["opponent"]]["rating"]
          this_diff = r1 - r2
          
          if this_diff > 0:
            mean_diffs[j] += this_diff
          
          for (k, r_diff) in enumerate(rating_diff_checks):
            if this_diff > r_diff:
              rating_diffs[j][k] += 1
    
    player_scores_mean  = [np.mean(s) for s in player_scores]
    player_scores_stdev = [np.std(s)  for s in player_scores]
    
    mean_inequalities = [m / num_trials for m in mean_inequalities]
    
    out_file = "{}/{}_{}_players.csv".format(output_folder, player_set, accel)
    with open(out_file, "w") as f:
      header_line = "Rank,Rating,Mean_score,Stdev_score,Sims,Won,Top3"
      for gimme_def in gimme_definitions:
        for i in range(4):
          header_line = "{},{}_{},Won_{}_{},Top3_{}_{}".format(
            header_line,
            i, gimme_def,
            i, gimme_def,
            i, gimme_def)
      
      f.write("{}\n".format(header_line))
      for (i, (m, s)) in enumerate(zip(player_scores_mean, player_scores_stdev)):
        this_line = "{},{},{:.4f},{:.4f},{},{},{}".format(
          i,
          base_players[i]["rating"],
          m,
          s,
          num_trials,
          player_wins[i],
          player_top3s[i])
        
        for (j, gimme_def) in enumerate(gimme_definitions):
          for k in range(4):
            this_line = "{},{},{},{}".format(
              this_line,
              player_gimmes[i][k][j],
              player_gimmes_wins[i][k][j],
              player_gimmes_top3s[i][k][j])
        
        f.write("{}\n".format(this_line))
    
    # Denominator: games per round summed over all trials.
    num_games = num_trials * num_players/2
    out_file = "{}/{}_{}_mismatches.csv".format(output_folder, player_set, accel)
    mean_diffs = ["{:.2f}".format(d/num_games) for d in mean_diffs]
    
    with open(out_file, "w") as f:
      f.write("Stat,{}\n".format(",".join(["round{}".format(i+1) for i in range(num_rounds)])))
      for (j, r_diff) in enumerate(rating_diff_checks):
        this_line = "rating > {}".format(r_diff)
        for i in range(num_rounds):
          this_perc = 100*rating_diffs[i][j] / num_games
          this_line = "{},{:.2f}".format(this_line, this_perc)
        f.write("{}\n".format(this_line))
      
      f.write("Mean_diff,{}\n".format(",".join(mean_diffs)))
    
    out_file = "{}/{}_{}_inequality.csv".format(output_folder, player_set, accel)
    with open(out_file, "w") as f:
      f.write("Stat,Value\n")
      for (i, g) in enumerate(gimme_definitions):
        f.write("ineq_gimme_{},{:.4f}\n".format(g, mean_inequalities[i]))
      
  time_end = time.time()
  print("Calculation took {} seconds".format(time_end - time_start))
