Skip to content

Commit

Permalink
Merge branch 'master' into develop
Browse files Browse the repository at this point in the history
  • Loading branch information
Agony5757 committed Mar 31, 2022
2 parents 3171adf + bdfb094 commit 904722a
Show file tree
Hide file tree
Showing 14 changed files with 405,737 additions and 56 deletions.
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ find_package(PythonInterp 3 REQUIRED)
find_package(PythonLibs 3 REQUIRED)
include_directories(${PYTHON_INCLUDE_DIRS})

configure_file(${CMAKE_CURRENT_SOURCE_DIR}/resource/syanten.dat ${CMAKE_CURRENT_BINARY_DIR}/resource/syanten.dat COPYONLY)

add_subdirectory(ThirdParty EXCLUDE_FROM_ALL)
add_subdirectory(Mahjong)
add_subdirectory(MahjongPyWrapper)
Expand Down
19 changes: 12 additions & 7 deletions Mahjong/Player.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,12 @@ void Player::remove_听牌(BaseTile t)
// Player::get_except_tiles()
}

vector<SelfAction> Player::get_加杠()
// Returns the value Syanten::normal_round_to_win computes for this player's
// concealed hand, given how many open melds (副露) the player currently holds.
int Player::get_normal向胡数()const {
    const int open_meld_count = 副露s.size();
    return Syanten::instance().normal_round_to_win(hand, open_meld_count);
}

vector<SelfAction> Player::get_加杠() const
{
vector<SelfAction> actions;
//if (after_chipon == true) return actions;
Expand All @@ -116,7 +121,7 @@ vector<SelfAction> Player::get_加杠()
return actions;
}

vector<SelfAction> Player::get_暗杠()
vector<SelfAction> Player::get_暗杠() const
{
vector<SelfAction> actions;

Expand All @@ -133,7 +138,7 @@ vector<SelfAction> Player::get_暗杠()
return actions;
}

static bool is食替(Player* player, BaseTile t)
static bool is食替(const Player* player, BaseTile t)
{
// 拿出最后一个副露
auto last_fulu = player->副露s.back();
Expand Down Expand Up @@ -192,7 +197,7 @@ static bool is食替(Player* player, BaseTile t)
throw runtime_error("最后一手既不是吃又不是碰,不考虑食替");
}

vector<SelfAction> Player::get_打牌(bool after_chipon)
vector<SelfAction> Player::get_打牌(bool after_chipon) const
{
profiler _("Player.cpp/get_discard");
vector<SelfAction> actions;
Expand All @@ -210,7 +215,7 @@ vector<SelfAction> Player::get_打牌(bool after_chipon)
return actions;
}

vector<SelfAction> Player::get_自摸(Table* table)
vector<SelfAction> Player::get_自摸(const Table* table) const
{
vector<SelfAction> actions;

Expand All @@ -227,7 +232,7 @@ vector<SelfAction> Player::get_自摸(Table* table)
return actions;
}

vector<SelfAction> Player::get_立直()
vector<SelfAction> Player::get_立直() const
{
vector<SelfAction> actions;

Expand All @@ -254,7 +259,7 @@ vector<SelfAction> Player::get_立直()
// return counter;
// }

vector<SelfAction> Player::get_九种九牌()
vector<SelfAction> Player::get_九种九牌() const
{
vector<SelfAction> actions;
if (!first_round) return actions;
Expand Down
21 changes: 11 additions & 10 deletions Mahjong/Player.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "Tile.h"
#include "macro.h"
#include "Rule.h"
#include "RoundToWin.h"

namespace_mahjong

Expand Down Expand Up @@ -59,7 +60,7 @@ class River {
return river[n];
}

inline size_t size() {
inline size_t size() const {
return river.size();
}

Expand Down Expand Up @@ -94,8 +95,8 @@ class Player {

Player();
Player(int init_score);
inline bool is_riichi() { return riichi || double_riichi; }
inline bool is振听() { return 同巡振听 || 舍牌振听 || 立直振听; }
inline bool is_riichi() const { return riichi || double_riichi; }
inline bool is振听() const { return 同巡振听 || 舍牌振听 || 立直振听; }
inline std::vector<Fulu> get_fuuros() { return 副露s; }
inline River get_river() { return river; }
std::string hand_to_string() const;
Expand All @@ -111,14 +112,14 @@ class Player {
if (riichi) 立直振听 = true;
else 同巡振听 = true;
}

int get_normal向胡数() const;
// Generate SelfAction
std::vector<SelfAction> get_加杠(); // 能否杠的过滤统一交给Table
std::vector<SelfAction> get_暗杠(); // 能否杠的过滤统一交给Table
std::vector<SelfAction> get_打牌(bool after_chipon);
std::vector<SelfAction> get_自摸(Table* table);
std::vector<SelfAction> get_立直();
std::vector<SelfAction> get_九种九牌();
std::vector<SelfAction> get_加杠() const ; // 能否杠的过滤统一交给Table
std::vector<SelfAction> get_暗杠() const ; // 能否杠的过滤统一交给Table
std::vector<SelfAction> get_打牌(bool after_chipon) const;
std::vector<SelfAction> get_自摸(const Table* table) const ;
std::vector<SelfAction> get_立直() const;
std::vector<SelfAction> get_九种九牌() const;

// Generate ResponseAction
std::vector<ResponseAction> get_荣和(Table*, Tile* tile);
Expand Down
120 changes: 120 additions & 0 deletions Mahjong/RoundToWin.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
#include "RoundToWin.h"
#include <algorithm>
#include <array>
#include <cstdint>
#include <cstring>
#include <fstream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <tuple>
#include <vector>

namespace_mahjong

// Packed tile-count layout for one suit: every rank owns a 3-bit field and
// rank 1 sits in the lowest bits. Entry i is the increment that adds one tile
// of rank i + 1 to a packed code.
constexpr static std::array<uint32_t, 9> tile_to_bit = {
    uint32_t{1} << 0,   // _1
    uint32_t{1} << 3,   // _2
    uint32_t{1} << 6,
    uint32_t{1} << 9,
    uint32_t{1} << 12,
    uint32_t{1} << 15,
    uint32_t{1} << 18,
    uint32_t{1} << 21,
    uint32_t{1} << 24,  // _9
};

// Renders a packed suit code as nine decimal digits, one per 3-bit field,
// starting with the lowest field (rank 1) and ending with rank 9.
std::string Syanten::code_to_string(uint32_t code) {
    std::string digits;
    int fields_left = 9;
    while (fields_left-- > 0) {
        const uint32_t count = code & 0b111;
        digits += std::to_string(count);
        code >>= 3;
    }
    return digits;
}

void Syanten::load_syanten_map() {
auto& result = syanten_map;
std::fstream f;
f.open("../resource/syanten.dat", std::ios::in);
if (!f.good()) {
throw std::runtime_error("open syanten.dat error.\nPlease put \"syanten.dat\" to current path.");
}
std::string line;
while (getline(f, line)) {
std::string code_str;
std::stringstream ss;
ss << line;
ss >> code_str;
uint32_t key = 0;
int value[4];
for (int i = 0; i < 9; i++) {
key += tile_to_bit[i] * (code_str[i] - '0');
}
ss >> value[0] >> value[1] >> value[2] >> value[3];
result[key] = std::make_tuple(value[0], value[1], value[2], value[3]);
}
f.close();
// for(auto iter=result.cbegin();iter!=result.cend();iter++){
// int value[4];
// std::tie(value[0], value[1], value[2], value[3]) = iter->second;
// std:: cout << code_to_string(iter->first) << ' ' << value[0] << ' ' << value[1]
// << ' ' << value[2] << ' ' << value[3] << std::endl;
// }
if (result.size() != 405350) {
throw std::runtime_error("syanten.dat broken!");
}
// return result;
}

// Packs a hand into four 32-bit codes written to code[0..3].
// The tile's BaseTile value divided by (_1p - _1m) selects the code, and the
// remainder selects which 3-bit count field inside that code is incremented.
// Author's example (digits listed rank 1 first):
// 1m 2m 6m 9m 1s 3s 5s 7s 1p 1p 1p 1p 1z => [110001001, 101010100, 400000000, 100000000]
void Syanten::hand_to_code(const std::vector<Tile*>& hand, uint32_t* code){
    for (int group = 0; group < 4; ++group) {
        code[group] = 0;
    }
    const auto suit_size = _1p - _1m;
    for (const Tile* held : hand) {
        const BaseTile t = held->tile;
        code[t / suit_size] += tile_to_bit[t % suit_size];
    }
}

// Scores a packed hand (hand_code[0..3], as produced by hand_to_code) and
// returns 9 - 2 * ptm - ptt - 2 * num_副露.
//
// hand_code[0..2] are looked up in syanten_map. Each entry holds two
// (m, t) candidates; the one with the larger 2 * m + t is used, and the first
// wins ties. hand_code[3] is scanned field by field: a count of 3 or more
// adds to ptm, and a count of exactly 2 adds to ptt.
//
// The lookup uses find() rather than operator[] so that a code missing from
// the table is not inserted as a side effect. A missing code contributes
// nothing, which matches the all-zero entry operator[] would have created.
int Syanten::_check_normal(const uint32_t* hand_code, int num_副露) {
    int ptm = 0, ptt = 0;
    for (int j = 0; j < 3; j++) {
        const auto entry = syanten_map.find(hand_code[j]);
        if (entry == syanten_map.end()) {
            continue;
        }
        int pt1m, pt1t, pt2m, pt2t;
        std::tie(pt1m, pt1t, pt2m, pt2t) = entry->second;
        if (pt1m * 2 + pt1t >= pt2m * 2 + pt2t) {
            ptm += pt1m;
            ptt += pt1t;
        }
        else {
            ptm += pt2m;
            ptt += pt2t;
        }
    }

    for (int i = 0; i < 7; i++) {
        int num = 0b111 & (hand_code[3] >> (3 * i));
        if (num >= 3) {
            ptm++; // meld (面子)
        }
        else if (num >= 2) {
            ptt++; // pair (雀头)
        }
    }
    // At most 4 - num_副露 groups may be counted: drop ptt first, then ptm.
    while (ptm + ptt > 4 - num_副露 && ptt > 0) ptt--;
    while (ptm + ptt > 4 - num_副露) ptm--;
    return 9 - ptm * 2 - ptt - num_副露 * 2;
}

// Returns the smallest _check_normal score over the hand as given and over
// every variant in which two copies of one tile are set aside first; a
// variant's score is reduced by one. The lookup table is loaded on demand
// when is_loaded is false.
int Syanten::normal_round_to_win(const std::vector<Tile*>& hand, int num_副露) {
    if (!is_loaded) {
        load_syanten_map();
    }

    uint32_t hand_code[4];
    hand_to_code(hand, hand_code);

    const auto suit_size = _1p - _1m;
    int best = 14;
    for (int idx = _1m; idx <= _7z; ++idx) {
        uint32_t& group_code = hand_code[idx / suit_size];
        const int held = 0b111 & (group_code >> (3 * (idx % suit_size)));
        if (held < 2) {
            continue;
        }
        // Temporarily remove two copies, score the remainder, then restore.
        const uint32_t two_tiles = 2 * tile_to_bit[idx % suit_size];
        group_code -= two_tiles;
        best = std::min(best, _check_normal(hand_code, num_副露) - 1);
        group_code += two_tiles;
    }
    return std::min(best, _check_normal(hand_code, num_副露));
}

namespace_mahjong_end
26 changes: 26 additions & 0 deletions Mahjong/RoundToWin.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#ifndef ROUNDTOWIN_H
#define ROUNDTOWIN_H

#include "Tile.h"
#include <map>

namespace_mahjong

// Process-wide helper that scores hands against a lookup table.
// The only way to obtain an object is instance(), because the default
// constructor is private.
class Syanten {
private:
    // False until the lookup table has been read.
    bool is_loaded = false;
    // Packed 9-rank tile-count code -> four integers read from the data file.
    std::map<uint32_t, std::tuple<int, int, int, int>> syanten_map;

    Syanten() = default;

    // Fills syanten_map from the data file.
    void load_syanten_map();
    // Packs `hand` into four codes written to code[0..3].
    void hand_to_code(const std::vector<Tile*>& hand, /*OUT*/ uint32_t* code);
    // Scores an already packed hand for the given number of open melds.
    int _check_normal(const uint32_t* hand_code, int num_副露);
    // Formats one packed code as a digit string.
    std::string code_to_string(uint32_t code);

public:
    // Returns the single shared object, a function-local static.
    static Syanten& instance() {
        static Syanten singleton;
        return singleton;
    }
    // Scores `hand` given the number of open melds (副露).
    int normal_round_to_win(const std::vector<Tile*>& hand, int num_副露);
};

namespace_mahjong_end
#endif // end #ifndef ROUNDTOWIN_H
2 changes: 1 addition & 1 deletion Mahjong/ScoreCounter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -970,7 +970,7 @@ static vector<pair<vector<Yaku>, int>> get_手役_from_complete_tiles(
return yaku_fus;
}

CounterResult yaku_counter(Table *table, Player &player, Tile *correspond_tile, bool 抢杠, bool 抢暗杠, Wind 自风, Wind 场风)
CounterResult yaku_counter(const Table *table, const Player &player, Tile *correspond_tile, bool 抢杠, bool 抢暗杠, Wind 自风, Wind 场风)
{
// 首先 假设进入到这个counter阶段的,至少满足了和牌条件的牌型
// 以及,是否有某种役是不确定的
Expand Down
2 changes: 1 addition & 1 deletion Mahjong/ScoreCounter.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class Tile;

// turn 判定役的玩家
// correspond_tile (自摸为nullptr,荣和为荣和牌)
CounterResult yaku_counter(Table *table, Player &player, Tile* correspond_tile, bool 抢杠, bool 抢暗杠, Wind 自风, Wind 场风);
CounterResult yaku_counter(const Table *table, const Player &player, Tile* correspond_tile, bool 抢杠, bool 抢暗杠, Wind 自风, Wind 场风);

//CounterResult yaku_counter_v2(Table *table, Player &player, Tile* correspond_tile, bool 抢杠, bool 抢暗杠, Wind 自风, Wind 场风);

Expand Down
1 change: 0 additions & 1 deletion MahjongPyWrapper/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,3 @@ elseif(WIN32)
endif(UNIX)

target_link_libraries(${LIB_MahjongPy} MahjongCore ${PYTHON_LIBRARIES} fmt tenhou_shuffle)

13 changes: 5 additions & 8 deletions pymahjong/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,18 +71,15 @@ Note: In a Mahjong game, it is possible the game is over before a certain player
### Pretrained opponent agents
We provide two pretrained models that can serve as opponents in the single-agent environment (see the paper https://openreview.net/forum?id=pjqqxepwoMy).

To use the pretrained models, you need to have [PyTorch](https://pytorch.org/) installed. You can download the models from [this link](https://1drv.ms/u/s!AuxZyB8UeEtsgpNScPpUjF1c09gaZQ?e=j4lS05) and put the .model files at the same directory as your python script. The pretrained model should automatically enable CUDA if your PyTorch supports CUDA.
To use the pretrained models, you need to have [PyTorch](https://pytorch.org/) installed. You can [download the models from the GitHub release](https://github.com/Agony5757/mahjong/releases/tag/v1.0.0). The pretrained model should automatically enable CUDA if your PyTorch supports CUDA.

Variational Latent Oracle Guiding + Conservative Q-learning
- `mahjong_VLOG_CQL.pth`: Variational Latent Oracle Guiding + Conservative Q-learning
- `mahjong_VLOG_BC.pth`: Variational Latent Oracle Guiding + Behavior Cloning
```
env = pymahjong.SingleAgentMahjongEnv(opponent_agent="vlog-cql")
```

Variational Latent Oracle Guiding + Behavior Cloning
```
env = pymahjong.SingleAgentMahjongEnv(opponent_agent="vlog-bc")
env = pymahjong.SingleAgentMahjongEnv(opponent_agent=path_to_the_model_file)
```

According to our tests, the two models perform similarly when playing against each other. However, the BC model plays in a much more defensive style than the CQL model.


## Multi-agent version
Expand Down
24 changes: 24 additions & 0 deletions pymahjong/base_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,30 @@ def forward(self, x):
return phi


def make_cnn(resolution, n_channels):
    """Build the convolutional feature extractor for an input resolution.

    Args:
        resolution: "10x10" for the MinAtar-style stack, or "34" for the
            Mahjong network.
        n_channels: Number of input channels.

    Returns:
        A ``(module, phi_size)`` tuple. For "10x10" the module is an
        ``nn.Sequential`` ending in ``nn.Flatten`` and ``phi_size`` is 256.
        For "34" the module is a ``MahjongNet`` and ``phi_size`` is its
        ``phi_size`` attribute.

    Raises:
        ValueError: If ``resolution`` is not one of the supported values.
            Previously this case failed on an undefined local variable.
    """
    if resolution == "10x10":
        # -------- MinAtar ---------
        cnn = nn.Sequential(
            nn.Conv2d(n_channels, 16, 3, 1, 0),
            nn.ReLU(),
            nn.Conv2d(16, 32, 3, 1, 0),
            nn.ReLU(),
            nn.Conv2d(32, 128, 4, 2, 0),
            nn.ReLU(),
            nn.Conv2d(128, 256, 2, 1, 0),
            nn.ReLU(),
            nn.Flatten(),
        )
        phi_size = 256
        return cnn, phi_size

    if resolution == "34":
        # -------- for Mahjong ---------
        mahjong_net = MahjongNet(n_channels)
        phi_size = mahjong_net.phi_size
        return mahjong_net, phi_size

    raise ValueError(
        "Unsupported resolution {!r}; expected '10x10' or '34'".format(resolution)
    )


class DiscreteActionQNetwork(nn.Module):
def __init__(self, input_size, output_size, hidden_layers=None, dueling=False, act_fn=nn.ReLU):
super(DiscreteActionQNetwork, self).__init__()
Expand Down
Loading

0 comments on commit 904722a

Please sign in to comment.