From fbc3c8ccb8b38eb8da794a9c9261ff53d2c19fa6 Mon Sep 17 00:00:00 2001 From: Samy Sultan Date: Sun, 7 Mar 2021 18:33:09 +0200 Subject: [PATCH] pre-release version 2 of the client major changes: 1- update client version to: 2- update ttc version to: 9 3- use 4 byte packet length instead of 2 bytes 4- use advanced negotiation 5- define NCHARSET in the data type negotiation and 6- use big clear chunks 7- use more verifier type in authentication object 8- receive session properties after authentication without calling getNLSData() --- advanced_nego/advanced_nego.go | 142 ++++++ advanced_nego/auth_service.go | 157 +++++++ advanced_nego/data_integrity_service.go | 137 ++++++ advanced_nego/default_service.go | 227 +++++++++ advanced_nego/encrypt_service.go | 131 ++++++ advanced_nego/supervisor_service.go | 102 ++++ auth_object.go | 388 +++++++++++---- command.go | 139 ++++-- connection.go | 336 +++++++++++-- data_type_nego.go | 598 ++++++++++++++++++++---- db_version.go | 6 +- lob.go | 47 +- network/accept_packet.go | 56 ++- network/connect_option.go | 9 +- network/connect_packet.go | 31 +- network/data_packet.go | 58 ++- network/marker_packet.go | 21 +- network/packets.go | 6 +- network/redirect_packet.go | 2 +- network/refuse_packet.go | 2 +- network/security/diffie_hellman.go | 29 ++ network/session.go | 192 ++++---- network/session_ctx.go | 36 +- network/summary_object.go | 22 +- parameter.go | 19 +- ref_cursor.go | 8 +- tcp_protocol_nego.go | 10 + 27 files changed, 2450 insertions(+), 461 deletions(-) create mode 100644 advanced_nego/advanced_nego.go create mode 100644 advanced_nego/auth_service.go create mode 100644 advanced_nego/data_integrity_service.go create mode 100644 advanced_nego/default_service.go create mode 100644 advanced_nego/encrypt_service.go create mode 100644 advanced_nego/supervisor_service.go create mode 100644 network/security/diffie_hellman.go diff --git a/advanced_nego/advanced_nego.go b/advanced_nego/advanced_nego.go new file mode 100644 index 00000000..624e6f7f --- /dev/null +++ b/advanced_nego/advanced_nego.go @@ -0,0 +1,142 @@ +package advanced_nego + +import ( + "errors" + "github.com/sijms/go-ora/network" +) + +type AdvNego struct { + serviceList []AdvNegoService +} + +func NewAdvNego(connOption *network.ConnectionOption) (*AdvNego, error) { + output := &AdvNego{ + serviceList: make([]AdvNegoService, 5), + } + var err error + output.serviceList[1], err = NewAuthService(connOption) + if err != nil { + return nil, err + } + output.serviceList[2], err = NewEncryptService(connOption) + if err != nil { + return nil, err + } + output.serviceList[3], err = NewDataIntegrityService(connOption) + if err != nil { + return nil, err + } + output.serviceList[4], err = NewSupervisorService() + if err != nil { + return nil, err + } + return output, nil +} +func (nego *AdvNego) readHeader(session *network.Session) ([]int, error) { + num, err := session.GetInt(4, false, true) + if err != nil { + return nil, err + } + if num != 0xDEADBEEF { + return nil, errors.New("advanced negotiation error: during receive header") + } + output := make([]int, 4) + output[0], err = session.GetInt(2, false, true) + if err != nil { + return nil, err + } + output[1], err = session.GetInt(4, false, true) + if err != nil { + return nil, err + } + output[2], err = session.GetInt(2, false, true) + if err != nil { + return nil, err + } + output[3], err = session.GetInt(1, false, true) + return output, err +} +func (nego *AdvNego) readServiceHeader(session *network.Session) ([]int, error) { + output := make([]int, 3) + var err error + output[0], err = session.GetInt(2, false, true) + if err != nil { + return nil, err + } + output[1], err = session.GetInt(2, false, true) + if err != nil { + return nil, err + } + output[2], err = session.GetInt(4, false, true) + return output, err +} +func (nego *AdvNego) Read(session *network.Session) error { + header, err := nego.readHeader(session) + if err != nil { + return err + } + for i := 0; i < header[2]; i++ { + serviceHeader, err := nego.readServiceHeader(session) + if err != nil { + return err + } + if serviceHeader[2] != 0 { + return errors.New("advanced negotiation error: during receive service header") + } + err = nego.serviceList[serviceHeader[0]].readServiceData(session, serviceHeader[1]) + if err != nil { + return err + } + err = nego.serviceList[serviceHeader[0]].validateResponse() + if err != nil { + return err + } + } + for i := 1; i < 5; i++ { + err = nego.serviceList[i].activateAlgorithm() + if err != nil { + return err + } + } + if authServ, ok := nego.serviceList[1].(*authService); ok { + if authServ.active { + errors.New("advanced negotiation: advanced authentication still not supported") + if authServ.serviceName == "KERBEROS5" { + + } else if authServ.serviceName == "NTS" { + + } + } + } + return nil +} +func (nego *AdvNego) Write(session *network.Session) error { + session.ResetBuffer() + size := 0 + for i := 1; i < 5; i++ { + size = size + 8 + nego.serviceList[i].getServiceDataLength() + } + size += 13 + session.PutInt(0xDEADBEEF, 4, true, false) + session.PutInt(size, 2, true, false) + session.PutInt(nego.serviceList[1].getVersion(), 4, true, false) + session.PutInt(4, 2, true, false) + session.PutBytes(0) + err := nego.serviceList[4].writeServiceData(session) + if err != nil { + return err + } + err = nego.serviceList[1].writeServiceData(session) + if err != nil { + return err + } + err = nego.serviceList[2].writeServiceData(session) + if err != nil { + return err + } + err = nego.serviceList[3].writeServiceData(session) + if err != nil { + return err + } + return session.Write() +} diff --git a/advanced_nego/auth_service.go b/advanced_nego/auth_service.go new file mode 100644 index 00000000..42b57c6a --- /dev/null +++ b/advanced_nego/auth_service.go @@ -0,0 +1,157 @@ +package advanced_nego + +import ( + "errors" + "github.com/sijms/go-ora/network" + "runtime" +) + +type authService struct { + defaultService + status int + serviceName string + active bool +} + +func NewAuthService(connOption *network.ConnectionOption) (*authService, error) { + output := &authService{ + defaultService: defaultService{ + serviceType: 1, + level: -1, + version: 0xB200200, + }, + status: 0xFCFF, + } + //var avaAuth []string + if runtime.GOOS == "windows" { + output.availableServiceNames = []string{"", "NTS", "KERBEROS5", "TCPS"} + output.availableServiceIDs = []int{0, 1, 1, 2} + } else { + output.availableServiceNames = []string{"TCPS"} + output.availableServiceIDs = []int{2} + } + str := "" + if connOption != nil { + snConfig := connOption.SNOConfig + if snConfig != nil { + var exists bool + str, exists = snConfig["sqlnet.authentication_services"] + if !exists { + str = "" + } + } + } + //level := conops.Encryption != null ? conops.Encryption : snoConfig[]; + err := output.buildServiceList(str, false, false) + //output.selectedServ, err = output.validate(strings.Split(str,","), true) + if err != nil { + return nil, err + } + return output, nil + /* user list is found in the dictionary + sessCtx.m_conops.SNOConfig["sqlnet.authentication_services"] + */ + /* you need to confirm that every item in user list found in avaAuth list + then for each item in userList you need to get index of it in the avaAuth + return output*/ +} + +func (serv *authService) writeServiceData(session *network.Session) error { + serv.writeHeader(session, 3+(len(serv.selectedIndices)*2)) + err := serv.writeVersion(session) + if err != nil { + return err + } + err = serv.writePacketHeader(session, 2, 3) + if err != nil { + return err + } + session.PutInt(0xE0E1, 2, true, false) + err = serv.writePacketHeader(session, 2, 6) + if err != nil { + return err + } + session.PutInt(serv.status, 2, true, false) + for i := 0; i < len(serv.selectedIndices); i++ { + index := serv.selectedIndices[i] + session.PutBytes(uint8(serv.availableServiceIDs[index])) + session.PutBytes([]byte(serv.availableServiceNames[index])...) + } + return nil +} + +func (serv *authService) readServiceData(session *network.Session, subPacketNum int) error { + // read version + var err error + serv.version, err = serv.readVersion(session) + if err != nil { + return err + } + // read status + _, err = serv.readPacketHeader(session, 6) + if err != nil { + return err + } + status, err := session.GetInt(2, false, true) + if err != nil { + return err + } + if status == 0xFAFF && subPacketNum > 2 { + // get 1 byte with header + _, err = serv.readPacketHeader(session, 2) + if err != nil { + return err + } + _, err = session.GetByte() + if err != nil { + return err + } + stringLen, err := serv.readPacketHeader(session, 0) + if err != nil { + return err + } + serviceNameBytes, err := session.GetBytes(stringLen) + if err != nil { + return err + } + serv.serviceName = string(serviceNameBytes) + if subPacketNum > 4 { + _, err = serv.readVersion(session) + if err != nil { + return err + } + _, err = serv.readPacketHeader(session, 4) + if err != nil { + return err + } + _, err = session.GetInt(4, false, true) + if err != nil { + return err + } + _, err = serv.readPacketHeader(session, 4) + if err != nil { + return err + } + _, err = session.GetInt(4, false, true) + if err != nil { + return err + } + } + serv.active = true + } else { + if status != 0xFBFF { + return errors.New("advanced negotiation error: reading authentication service") + } + serv.active = false + } + return nil +} + +func (serv *authService) getServiceDataLength() int { + size := 20 + for i := 0; i < len(serv.selectedIndices); i++ { + index := serv.selectedIndices[i] + size = size + 5 + (4 + len(serv.availableServiceNames[index])) + } + return size +} diff --git a/advanced_nego/data_integrity_service.go b/advanced_nego/data_integrity_service.go new file mode 100644 index 00000000..442c5e08 --- /dev/null +++ b/advanced_nego/data_integrity_service.go @@ -0,0 +1,137 @@ +package advanced_nego + +import ( + "errors" + "fmt" + "github.com/sijms/go-ora/network" +) + +type dataIntegrityService struct { + defaultService + algoID int +} + +func NewDataIntegrityService(connOption *network.ConnectionOption) (*dataIntegrityService, error) { + output := &dataIntegrityService{ + defaultService: defaultService{ + serviceType: 3, + version: 0xB200200, + availableServiceNames: []string{"", "MD5", "SHA1", "SHA512", "SHA256", "SHA384"}, + availableServiceIDs: []int{0, 1, 3, 4, 5, 6}, + }, + } + str := "" + level := "" + if connOption != nil { + snConfig := connOption.SNOConfig + if snConfig != nil { + var exists bool + str, exists = snConfig["sqlnet.crypto_checksum_types_client"] + if !exists { + str = "" + } + level, exists = snConfig["sqlnet.crypto_checksum_client"] + if !exists { + level = "" + } + } + } + output.readAdvNegoLevel(level) + //level := conops.Encryption != null ? conops.Encryption : snoConfig[]; + err := output.buildServiceList(str, true, true) + //output.selectedServ, err = output.validate(strings.Split(str,","), true) + if err != nil { + return nil, err + } + return output, nil +} + +func (serv *dataIntegrityService) readServiceData(session *network.Session, subPacketNum int) error { + var err error + serv.version, err = serv.readVersion(session) + if err != nil { + return err + } + _, err = serv.readPacketHeader(session, 2) + if err != nil { + return err + } + resp, err := session.GetByte() + if err != nil { + return err + } + serv.algoID = int(resp) + if subPacketNum != 8 { + return nil + } + return errors.New("diffie hellman key exchange still under development") + //dhGroupGLen, err := session.GetInt(2, false, true) + //if err != nil { + // return err + //} + //dhGroupPLen, err := session.GetInt(2, false, true) + //if err != nil { + // return err + //} + //raw1, err := serv.readBytes(session) + //if err != nil { + // return err + //} + //raw2, err := serv.readBytes(session) + //if err != nil { + // return err + //} + //raw3, err := serv.readBytes(session) + //if err != nil { + // return err + //} + //raw4, err := serv.readBytes(session) + //if err != nil { + // return err + //} + //if dhGroupGLen <= 0 || dhGroupPLen <= 0 { + // return errors.New("advanced negotiation error: bad parameter from server") + //} + //byteLen := (dhGroupPLen + 7) / 8 + //if len(raw3) != byteLen || len(raw2) != byteLen { + // return errors.New("advanced negotiation error: DiffieHellman negotiation out of sync") + //} +} +func (serv *dataIntegrityService) writeServiceData(session *network.Session) error { + serv.writeHeader(session, 2) + err := serv.writeVersion(session) + if err != nil { + return err + } + err = serv.writePacketHeader(session, len(serv.selectedIndices), 1) + if err != nil { + return err + } + for i := 0; i < len(serv.selectedIndices); i++ { + index := serv.selectedIndices[i] + session.PutBytes(uint8(serv.availableServiceIDs[index])) + } + return nil +} + +func (serv *dataIntegrityService) getServiceDataLength() int { + return 12 + len(serv.selectedIndices) +} + +func (serv *dataIntegrityService) activateAlgorithm() error { + if serv.algoID == 0 { + return nil + } else { + return errors.New(fmt.Sprintf("advanced negotiation error: data integrity service algorithm: %d still not supported", serv.algoID)) + + switch serv.availableServiceNames[serv.algoID] { + case "MD5": + case "SHA1": + case "SHA512": + case "SHA256": + case "SHA384": + } + return nil + // you can use also IDs + } +} diff --git a/advanced_nego/default_service.go b/advanced_nego/default_service.go new file mode 100644 index 00000000..78db0637 --- /dev/null +++ b/advanced_nego/default_service.go @@ -0,0 +1,227 @@ +package advanced_nego + +import ( + "errors" + "fmt" + "github.com/sijms/go-ora/network" + "strings" +) + +type AdvNegoService interface { + getServiceDataLength() int + writeServiceData(session *network.Session) error + readServiceData(session *network.Session, subPacketNum int) error + validateResponse() error + getVersion() uint32 + activateAlgorithm() error +} + +type defaultService struct { + serviceType int + level int + availableServiceNames []string + availableServiceIDs []int + selectedIndices []int + version uint32 + //selectedServ map[string]int + //avaServs map[string]int +} + +func (serv *defaultService) getVersion() uint32 { + return serv.version +} +func (serv *defaultService) activateAlgorithm() error { + return nil +} +func (serv *defaultService) writePacketHeader(session *network.Session, length, _type int) error { + // the driver call Anocommunication.ValidateType(length, type); + session.PutInt(length, 2, true, false) + session.PutInt(_type, 2, true, false) + return nil +} + +func (serv *defaultService) readPacketHeader(session *network.Session, _type int) (length int, err error) { + length, err = session.GetInt(2, false, true) + if err != nil { + return + } + receivedType, err := session.GetInt(2, false, true) + if err != nil { + return 0, err + } + if receivedType != _type { + err = errors.New("advanced negotiation error: received type is not as stored type") + return + } + err = serv.validatePacketHeader(length, receivedType) + return +} +func (serv *defaultService) validatePacketHeader(length, _type int) error { + if _type < 0 || _type > 7 { + return errors.New("advanced negotiation error: cannot validate packet header") + } + switch _type { + case 0, 1: + break + case 2: + if length > 1 { + return errors.New("advanced negotiation error: cannot validate packet header") + } + case 3: + fallthrough + case 6: + if length > 2 { + return errors.New("advanced negotiation error: cannot validate packet header") + } + case 4: + fallthrough + case 5: + if length > 4 { + return errors.New("advanced negotiation error: cannot validate packet header") + } + case 7: + if length < 10 { + return errors.New("advanced negotiation error: cannot validate packet header") + } + default: + return errors.New("advanced negotiation error: cannot validate packet header") + } + return nil +} +func (serv *defaultService) writeHeader(session *network.Session, serviceSubPackets int) { + session.PutInt(serv.serviceType, 2, true, false) + session.PutInt(serviceSubPackets, 2, true, false) + session.PutInt(0, 4, true, false) +} + +func (serv *defaultService) readVersion(session *network.Session) (uint32, error) { + _, err := serv.readPacketHeader(session, 5) + if err != nil { + return 0, err + } + version, err := session.GetInt(4, false, true) + return uint32(version), err + +} +func (serv *defaultService) readBytes(session *network.Session) ([]byte, error) { + length, err := serv.readPacketHeader(session, 1) + if err != nil { + return nil, err + } + return session.GetBytes(length) +} +func (serv *defaultService) writeVersion(session *network.Session) error { + err := serv.writePacketHeader(session, 4, 5) + if err != nil { + return err + } + session.PutInt(serv.getVersion(), 4, true, false) + return nil +} +func (serv *defaultService) readAdvNegoLevel(level string) { + level = strings.ToUpper(level) + if level == "" || level == "ACCEPTED" { + serv.level = 0 + } else if level == "REJECTED" { + serv.level = 1 + } else if level == "REQUESTED" { + serv.level = 2 + } else if level == "REQUIRED" { + serv.level = 3 + } else { + serv.level = -1 + } +} + +func (serv *defaultService) buildServiceList(servList string, useLevel, useDefault bool) error { + serv.selectedIndices = make([]int, 0, 10) + //serv.selectedServ = make(map[string]int) + if useLevel { + if serv.level == 1 { + serv.selectedIndices = append(serv.selectedIndices, 0) + //serv.selectedServ[""] = 0 + return nil + } + if serv.level != 0 && serv.level != 2 && serv.level != 3 { + return errors.New(fmt.Sprintf("unsupported service level value: %d", serv.level)) + } + } + userList := strings.Split(servList, ",") + userListLength := len(userList) + for i := 0; i < userListLength; i++ { + userList[i] = strings.TrimSpace(userList[i]) + } + if userListLength > 0 && userList[userListLength-1] == "" { + userList = userList[:userListLength-1] + } + if len(userList) == 0 { + if useDefault { + for i := 0; i < len(serv.availableServiceNames); i++ { + if serv.availableServiceNames[i] == "" { + if !(useLevel && serv.level == 0) { + continue + } + } + serv.selectedIndices = append(serv.selectedIndices, i) + } + if useLevel && serv.level == 2 { + serv.selectedIndices = append(serv.selectedIndices, 0) + //serv.selectedServ[""] = 0 + } + } + return nil + } else if len(userList) == 1 { + if strings.ToUpper(userList[0]) == "ALL" { + for i := 0; i < len(serv.availableServiceNames); i++ { + if serv.availableServiceNames[i] == "" { + if !(useLevel && serv.level == 0) { + continue + } + } + serv.selectedIndices = append(serv.selectedIndices, i) + } + if useLevel && serv.level == 2 { + serv.selectedIndices = append(serv.selectedIndices, 0) + //serv.selectedServ[""] = 0 + } + return nil + } else if strings.ToUpper(userList[0]) == "NONE" { + return nil + } + } + if useLevel && serv.level == 0 { + serv.selectedIndices = append(serv.selectedIndices, 0) + //serv.selectedServ[""] = 0 + } + for _, userItem := range userList { + if userItem == "" { + return errors.New("empty authentication service") + } + found := false + for i := 0; i < len(serv.availableServiceNames); i++ { + if strings.ToUpper(userItem) == serv.availableServiceNames[i] { + serv.selectedIndices = append(serv.selectedIndices, i) + found = true + break + } + } + //for key, value := range serv.avaServs { + // if strings.ToUpper(userItem) == key { + // serv.selectedServ[key] = value + // //output = append(output, userItem) + // found = true + // break + // } + //} + if !found { + return errors.New("unsupported authentication service") + } + } + if useLevel && serv.level == 2 { + serv.selectedIndices = append(serv.selectedIndices, 0) + } + return nil +} +func (serv *defaultService) validateResponse() error { + return nil +} diff --git a/advanced_nego/encrypt_service.go b/advanced_nego/encrypt_service.go new file mode 100644 index 00000000..f47bbea1 --- /dev/null +++ b/advanced_nego/encrypt_service.go @@ -0,0 +1,131 @@ +package advanced_nego + +import ( + "errors" + "fmt" + "github.com/sijms/go-ora/network" +) + +type encryptService struct { + defaultService + algoID int +} + +func NewEncryptService(connOption *network.ConnectionOption) (*encryptService, error) { + output := &encryptService{ + defaultService: defaultService{ + serviceType: 2, + version: 0xB200200, + availableServiceNames: []string{"", "RC4_40", "RC4_56", "RC4_128", "RC4_256", + "DES40C", "DES56C", "3DES112", "3DES168", "AES128", "AES192", "AES256"}, + availableServiceIDs: []int{0, 1, 8, 10, 6, 3, 2, 11, 12, 15, 16, 17}, + }, + } + str := "" + level := "" + if connOption != nil { + snConfig := connOption.SNOConfig + if snConfig != nil { + var exists bool + str, exists = snConfig["sqlnet.encryption_types_client"] + if !exists { + str = "" + } + level, exists = snConfig["sqlnet.encryption_client"] + if !exists { + level = "" + } + } + } + output.readAdvNegoLevel(level) + //level := conops.Encryption != null ? conops.Encryption : snoConfig[]; + err := output.buildServiceList(str, true, true) + //output.selectedServ, err = output.validate(strings.Split(str,","), true) + if err != nil { + return nil, err + } + return output, nil +} + +func (serv *encryptService) readServiceData(session *network.Session, subPacketnum int) error { + var err error + serv.version, err = serv.readVersion(session) + if err != nil { + return err + } + _, err = serv.readPacketHeader(session, 2) + if err != nil { + return err + } + resp, err := session.GetByte() + if err != nil { + return err + } + serv.algoID = int(resp) + return nil +} +func (serv *encryptService) writeServiceData(session *network.Session) error { + serv.writeHeader(session, 3) + err := serv.writeVersion(session) + if err != nil { + return err + } + err = serv.writePacketHeader(session, len(serv.selectedIndices), 1) + if err != nil { + return err + } + for i := 0; i < len(serv.selectedIndices); i++ { + index := serv.selectedIndices[i] + session.PutBytes(uint8(serv.availableServiceIDs[index])) + } + // send selected driver + err = serv.writePacketHeader(session, 1, 2) + if err != nil { + return err + } + session.PutBytes(1) + return nil +} + +func (serv *encryptService) getServiceDataLength() int { + return 17 + len(serv.selectedIndices) +} + +func (serv *encryptService) activateAlgorithm() error { + if serv.algoID == 0 { + return nil + } else { + return errors.New(fmt.Sprintf("advanced negotiation error: encryption service algorithm: %d still not supported", serv.algoID)) + } + //switch (this.m_algID) + //{ + //case 1: + // this.m_sessCtx.encryptionAlg = (EncryptionAlgorithm) new RC4(true, 40); + // break; + //case 6: + // this.m_sessCtx.encryptionAlg = (EncryptionAlgorithm) new RC4(true, 256); + // break; + //case 8: + // this.m_sessCtx.encryptionAlg = (EncryptionAlgorithm) new RC4(true, 56); + // break; + //case 10: + // this.m_sessCtx.encryptionAlg = (EncryptionAlgorithm) new RC4(true, 128); + // break; + //case 11: + // this.m_sessCtx.encryptionAlg = (EncryptionAlgorithm) new DES112(); + // break; + //case 12: + // this.m_sessCtx.encryptionAlg = (EncryptionAlgorithm) new DES168(); + // break; + //case 15: + // this.m_sessCtx.encryptionAlg = (EncryptionAlgorithm) new AES(1, 1); + // break; + //case 16: + // this.m_sessCtx.encryptionAlg = (EncryptionAlgorithm) new AES(1, 2); + // break; + //case 17: + // this.m_sessCtx.encryptionAlg = (EncryptionAlgorithm) new AES(1, 3); + // break; + //} + //this.m_sessCtx.encryptionAlg.init(ano.skey, ano.getInitializationVector()); +} diff --git a/advanced_nego/supervisor_service.go b/advanced_nego/supervisor_service.go new file mode 100644 index 00000000..7f9d7450 --- /dev/null +++ b/advanced_nego/supervisor_service.go @@ -0,0 +1,102 @@ +package advanced_nego + +import ( + "errors" + "github.com/sijms/go-ora/network" +) + +type supervisorService struct { + defaultService + cid []byte + servArray []int +} + +func NewSupervisorService() (*supervisorService, error) { + output := &supervisorService{ + defaultService: defaultService{ + serviceType: 4, + version: 0xB200200, + }, + cid: []byte{0, 0, 16, 28, 102, 236, 40, 234}, + servArray: []int{4, 1, 2, 3}, + } + return output, nil +} + +func (serv *supervisorService) readServiceData(session *network.Session, subPacketNum int) error { + var err error + _, err = serv.readVersion(session) + if err != nil { + return err + } + _, err = serv.readPacketHeader(session, 6) + if err != nil { + return err + } + status, err := session.GetInt(2, false, true) + if err != nil { + return err + } + if status != 31 { + return errors.New("advanced negotiation error: reading supervisor service") + } + + _, err = serv.readPacketHeader(session, 1) + if err != nil { + return err + } + num1, err := session.GetInt(4, false, true) + if err != nil { + return err + } + num2, err := session.GetInt(2, false, true) + if err != nil { + return err + } + size, err := session.GetInt(4, false, true) + if err != nil { + return err + } + if num1 != 0xDEADBEEF || num2 != 3 { + return errors.New("advanced negotiation error: reading supervisor service") + } + serv.servArray = make([]int, size) + for i := 0; i < size; i++ { + serv.servArray[i], err = session.GetInt(2, false, true) + if err != nil { + return err + } + } + return nil +} + +func (serv *supervisorService) writeServiceData(session *network.Session) error { + serv.writeHeader(session, 3) + err := serv.writeVersion(session) + if err != nil { + return err + } + // send cid + err = serv.writePacketHeader(session, len(serv.cid), 1) + if err != nil { + return err + } + session.PutBytes(serv.cid...) + + // send the serv-array + err = serv.writePacketHeader(session, 10+len(serv.servArray)*2, 1) + if err != nil { + return err + } + session.PutInt(0xDEADBEEF, 4, true, false) + session.PutInt(3, 2, true, false) + session.PutInt(len(serv.servArray), 4, true, false) + for i := 0; i < len(serv.servArray); i++ { + session.PutInt(serv.servArray[i], 2, true, false) + } + return nil +} + +func (serv *supervisorService) getServiceDataLength() int { + return 12 + len(serv.cid) + 4 + 10 + (len(serv.servArray) * 2) +} diff --git a/auth_object.go b/auth_object.go index 61f50319..562ec4c8 100644 --- a/auth_object.go +++ b/auth_object.go @@ -5,9 +5,11 @@ import ( "crypto/aes" "crypto/cipher" "crypto/des" + "crypto/hmac" "crypto/md5" "crypto/rand" "crypto/sha1" + "crypto/sha512" "errors" "fmt" "github.com/sijms/go-ora/network" @@ -18,20 +20,32 @@ import ( // E infront of the variable means encrypted type AuthObject struct { - EServerSessKey string - EClientSessKey string - EPassword string - ServerSessKey []byte - ClientSessKey []byte - KeyHash []byte - Salt string - VerifierType int - tcpNego *TCPNego + EServerSessKey string + EClientSessKey string + EPassword string + ESpeedyKey string + ServerSessKey []byte + ClientSessKey []byte + KeyHash []byte + Salt string + pbkdf2ChkSalt string + pbkdf2VgenCount int + pbkdf2SderCount int + globalUniqueDBID string + usePadding bool + customHash bool + VerifierType int + tcpNego *TCPNego } func NewAuthObject(username string, password string, tcpNego *TCPNego, session *network.Session) (*AuthObject, error) { ret := new(AuthObject) ret.tcpNego = tcpNego + ret.usePadding = false + ret.customHash = ret.tcpNego.ServerCompileTimeCaps[4]&32 != 0 + // the parameter srvCS_Multibyte will affect may thing in the logon process + //if (Conv.GetMaxBytesPerChar((int) this.m_serverCharacterSet) > 1) + //this.m_marshallingEngine.m_bSvrCSMultibyte = true; loop := true for loop { messageCode, err := session.GetByte() @@ -59,20 +73,55 @@ func NewAuthObject(username string, password string, tcpNego *TCPNego, session * return nil, err } if bytes.Compare(key, []byte("AUTH_SESSKEY")) == 0 { - ret.EServerSessKey = string(val) + if len(ret.EServerSessKey) == 0 { + ret.EServerSessKey = string(val) + } } else if bytes.Compare(key, []byte("AUTH_VFR_DATA")) == 0 { - ret.Salt = string(val) - ret.VerifierType = num + if len(ret.Salt) == 0 { + ret.Salt = string(val) + ret.VerifierType = num + } + } else if bytes.Compare(key, []byte("AUTH_PBKDF2_CSK_SALT")) == 0 { + if len(ret.pbkdf2ChkSalt) == 0 { + ret.pbkdf2ChkSalt = string(val) + if len(ret.pbkdf2ChkSalt) != 32 { + return nil, errors.New("ORA-28041: Authentication protocol internal error") + } + } + } else if bytes.Compare(key, []byte("AUTH_PBKDF2_VGEN_COUNT")) == 0 { + if ret.pbkdf2VgenCount == 0 { + ret.pbkdf2VgenCount, err = strconv.Atoi(string(val)) + if err != nil { + return nil, errors.New("ORA-28041: Authentication protocol internal error") + } + if ret.pbkdf2VgenCount < 4096 || ret.pbkdf2VgenCount > 100000000 { + ret.pbkdf2VgenCount = 4096 + } + } + } else if bytes.Compare(key, []byte("AUTH_PBKDF2_SDER_COUNT")) == 0 { + ret.pbkdf2SderCount, err = strconv.Atoi(string(val)) + if ret.pbkdf2SderCount == 0 { + if err != nil { + return nil, errors.New("ORA-28041: Authentication protocol internal error") + } + if ret.pbkdf2SderCount < 3 || ret.pbkdf2SderCount > 100000000 { + ret.pbkdf2SderCount = 3 + } + } } } default: return nil, errors.New(fmt.Sprintf("message code error: received code %d and expected code is 8", messageCode)) } } - + if len(ret.EServerSessKey) != 64 && len(ret.EServerSessKey) != 96 { + return nil, errors.New("session key should be either 64, 96 bytes long") + } var key []byte + var speedyKey []byte padding := false var err error + if ret.VerifierType == 2361 { key, err = getKeyFromUserNameAndPassword(username, password) if err != nil { @@ -95,7 +144,18 @@ func NewAuthObject(username string, password string, tcpNego *TCPNego, session * } key = hash.Sum(nil) // 20 byte key key = append(key, 0, 0, 0, 0) // 24 byte key + } else if ret.VerifierType == 18453 { + salt, err := HexStringToBytes(ret.Salt) + if err != nil { + return nil, err + } + message := append(salt, []byte("AUTH_PBKDF2_SPEEDY_KEY")...) + speedyKey = generateSpeedyKey(message, []byte(password), ret.pbkdf2VgenCount) + buffer := append(speedyKey, salt...) + hash := sha512.New() + hash.Write(buffer) + key = hash.Sum(nil)[:32] } else { return nil, errors.New("unsupported verifier type") } @@ -105,6 +165,7 @@ func NewAuthObject(username string, password string, tcpNego *TCPNego, session * return nil, err } + // note if serverSessKey length is less than the expected length according to verifier generate random one // generate new key for client ret.ClientSessKey = make([]byte, len(ret.ServerSessKey)) for { @@ -124,13 +185,21 @@ func NewAuthObject(username string, password string, tcpNego *TCPNego, session * } // get the hash key form server and client session key - ret.KeyHash, err = CalculateKeysHash(ret.VerifierType, ret.ServerSessKey[16:], ret.ClientSessKey[16:]) + newKey, err := ret.generatePasswordEncKey() if err != nil { return nil, err } - + if ret.VerifierType == 18453 { + padding = false + } else { + padding = true + } // encrypt the password - ret.EPassword, err = EncryptPassword(password, ret.KeyHash) + ret.EPassword, err = EncryptPassword([]byte(password), newKey, true) + if err != nil { + return nil, err + } + ret.ESpeedyKey, err = EncryptPassword(speedyKey, newKey, padding) if err != nil { return nil, err } @@ -138,60 +207,52 @@ func NewAuthObject(username string, password string, tcpNego *TCPNego, session * } func (obj *AuthObject) Write(connOption *network.ConnectionOption, mode LogonMode, session *network.Session) error { - session.ResetBuffer() - keyValSize := 22 - session.PutBytes(3, 0x73, 0) - if len(connOption.UserID) > 0 { - session.PutInt(1, 1, false, false) - session.PutInt(len(connOption.UserID), 4, true, true) - } else { - session.PutBytes(0, 0) - } - - if len(connOption.UserID) > 0 && len(obj.EPassword) > 0 { - mode |= UserAndPass - } - session.PutUint(int(mode), 4, true, true) - session.PutUint(1, 1, false, false) - session.PutUint(keyValSize, 4, true, true) - session.PutBytes(1, 1) - if len(connOption.UserID) > 0 { - session.PutBytes([]byte(connOption.UserID)...) + var keys = make([]string, 0, 20) + var values = make([]string, 0, 20) + var flags = make([]uint8, 0, 20) + appendKeyVal := func(key, val string, f uint8) { + keys = append(keys, key) + values = append(values, val) + flags = append(flags, f) } index := 0 if len(obj.EClientSessKey) > 0 { - session.PutKeyValString("AUTH_SESSKEY", obj.EClientSessKey, 1) + appendKeyVal("AUTH_SESSKEY", obj.EClientSessKey, 1) index++ } if len(obj.EPassword) > 0 { - session.PutKeyValString("AUTH_PASSWORD", obj.EPassword, 0) + appendKeyVal("AUTH_PASSWORD", obj.EPassword, 0) index++ } // if newpassword encrypt and add { // session.PutKeyValString("AUTH_NEWPASSWORD", ENewPassword, 0) // index ++ //} - session.PutKeyValString("AUTH_TERMINAL", connOption.ClientData.HostName, 0) + if len(obj.ESpeedyKey) > 0 { + appendKeyVal("AUTH_PBKDF2_SPEEDY_KEY", obj.ESpeedyKey, 0) + index++ + } + appendKeyVal("AUTH_TERMINAL", connOption.ClientData.HostName, 0) index++ - session.PutKeyValString("AUTH_PROGRAM_NM", connOption.ClientData.ProgramName, 0) + appendKeyVal("AUTH_PROGRAM_NM", connOption.ClientData.ProgramName, 0) index++ - session.PutKeyValString("AUTH_MACHINE", connOption.ClientData.HostName, 0) + appendKeyVal("AUTH_MACHINE", connOption.ClientData.HostName, 0) index++ - session.PutKeyValString("AUTH_PID", fmt.Sprintf("%d", connOption.ClientData.PID), 0) + appendKeyVal("AUTH_PID", fmt.Sprintf("%d", connOption.ClientData.PID), 0) index++ - session.PutKeyValString("AUTH_SID", connOption.ClientData.UserName, 0) + appendKeyVal("AUTH_SID", connOption.ClientData.UserName, 0) index++ - session.PutKeyValString("AUTH_CONNECT_STRING", connOption.ConnectionData(), 0) + appendKeyVal("AUTH_CONNECT_STRING", connOption.ConnectionData(), 0) index++ - session.PutKeyValString("SESSION_CLIENT_CHARSET", strconv.Itoa(int(obj.tcpNego.ServerCharset)), 0) + appendKeyVal("SESSION_CLIENT_CHARSET", strconv.Itoa(int(obj.tcpNego.ServerCharset)), 0) index++ - session.PutKeyValString("SESSION_CLIENT_LIB_TYPE", "0", 0) + appendKeyVal("SESSION_CLIENT_LIB_TYPE", "0", 0) index++ - session.PutKeyValString("SESSION_CLIENT_DRIVER_NAME", connOption.ClientData.DriverName, 0) + appendKeyVal("SESSION_CLIENT_DRIVER_NAME", connOption.ClientData.DriverName, 0) index++ - session.PutKeyValString("SESSION_CLIENT_VERSION", "1.0.0.0", 0) + appendKeyVal("SESSION_CLIENT_VERSION", "2.0.0.0", 0) index++ - session.PutKeyValString("SESSION_CLIENT_LOBATTR", "1", 0) + appendKeyVal("SESSION_CLIENT_LOBATTR", "1", 0) index++ _, offset := time.Now().Zone() tz := "" @@ -206,12 +267,7 @@ func (obj *AuthObject) Write(connOption *network.ConnectionOption, mode LogonMod } tz = fmt.Sprintf("%+03d:%02d", hours, minutes) } - //if !strings.Contains(tz, ":") { - // tz += ":00" - //} - //session.PutKeyValString("AUTH_ALTER_SESSION", - // fmt.Sprintf("ALTER SESSION SET NLS_LANGUAGE='ARABIC' NLS_TERRITORY='SAUDI ARABIA' TIME_ZONE='%s'\x00", tz), 1) - session.PutKeyValString("AUTH_ALTER_SESSION", + appendKeyVal("AUTH_ALTER_SESSION", fmt.Sprintf("ALTER SESSION SET NLS_LANGUAGE='AMERICAN' NLS_TERRITORY='AMERICA' TIME_ZONE='%s'\x00", tz), 1) index++ //if (!string.IsNullOrEmpty(proxyClientName)) @@ -229,16 +285,53 @@ func (obj *AuthObject) Write(connOption *network.ConnectionOption, mode LogonMod // keys[index1] = this.m_authSerialNum; // values[index1++] = this.m_marshallingEngine.m_dbCharSetConv.ConvertStringToBytes(serialNum.ToString(), 0, serialNum.ToString().Length, true); //} - // fill remaining values with zeros - for index < keyValSize { - session.PutKeyVal(nil, nil, 0) - index++ + session.ResetBuffer() + session.PutBytes(3, 0x73, 0) + if len(connOption.UserID) > 0 { + session.PutBytes(1) + session.PutInt(len(connOption.UserID), 4, true, true) + } else { + session.PutBytes(0, 0) } - err := session.Write() - if err != nil { - return err + // if proxy auth logonMode |= 0x400 + if len(connOption.UserID) > 0 && len(obj.EPassword) > 0 { + mode |= UserAndPass + } + session.PutUint(int(mode), 4, true, true) + session.PutBytes(1) + session.PutUint(index, 4, true, true) + session.PutBytes(1, 1) + if len(connOption.UserID) > 0 { + session.PutString(connOption.UserID) + } + for i := 0; i < index; i++ { + session.PutKeyValString(keys[i], values[i], flags[i]) + } + //fill remaining values with zeros + //for index < 30 { + // session.PutKeyVal(nil, nil, 0) + // index++ + //} + return session.Write() + +} + +func generateSpeedyKey(buffer, key []byte, turns int) []byte { + mac := hmac.New(sha512.New, key) + mac.Write(append(buffer, 0, 0, 0, 1)) + firstHash := mac.Sum(nil) + tempHash := make([]byte, len(firstHash)) + copy(tempHash, firstHash) + for index1 := 2; index1 <= turns; index1++ { + //mac = hmac.New(sha512.New, []byte("ter1234")) + mac.Reset() + mac.Write(tempHash) + tempHash = mac.Sum(nil) + for index2 := 0; index2 < 64; index2++ { + firstHash[index2] = firstHash[index2] ^ tempHash[index2] + } } - return nil + return firstHash } func getKeyFromUserNameAndPassword(username string, password string) ([]byte, error) { @@ -303,6 +396,7 @@ func HexStringToBytes(input string) ([]byte, error) { } return result, nil } + func decryptSessionKey(padding bool, encKey []byte, sessionKey string) ([]byte, error) { result, err := HexStringToBytes(sessionKey) if err != nil { @@ -343,58 +437,114 @@ func EncryptSessionKey(padding bool, encKey []byte, sessionKey []byte) (string, return "", err } enc := cipher.NewCBCEncrypter(blk, make([]byte, 16)) - if padding { - sessionKey = PKCS5Padding(sessionKey, blk.BlockSize()) - } + originalLen := len(sessionKey) + sessionKey = PKCS5Padding(sessionKey, blk.BlockSize()) + //if padding { + // + //} output := make([]byte, len(sessionKey)) enc.CryptBlocks(output, sessionKey) + if !padding { + return fmt.Sprintf("%X", output[:originalLen]), nil + } return fmt.Sprintf("%X", output), nil + + //cryptoServiceProvider.Mode = CipherMode.CBC; + //cryptoServiceProvider.KeySize = key.Length * 8; + //cryptoServiceProvider.BlockSize = O5LogonHelper.d; + //cryptoServiceProvider.Key = key; + //cryptoServiceProvider.IV = O5LogonHelper.f; + //numArray = cryptoServiceProvider.CreateEncryptor().TransformFinalBlock(buffer, 0, buffer.Length); } -func EncryptPassword(password string, key []byte) (string, error) { +func EncryptPassword(password, key []byte, padding bool) (string, error) { buff1 := make([]byte, 0x10) _, err := rand.Read(buff1) //buff_1 = []byte{109, 250, 127, 252, 157, 165, 29, 6, 165, 174, 50, 93, 165, 202, 192, 100} if err != nil { return "", nil } - buffer := append(buff1, []byte(password)...) - return EncryptSessionKey(true, key, buffer) + buffer := append(buff1, password...) + return EncryptSessionKey(padding, key, buffer) } -func CalculateKeysHash(verifierType int, key1 []byte, key2 []byte) ([]byte, error) { +//func bytesToHexString(input []byte) []byte { +// byteToHex := func(x uint8) uint8 { +// x &= 0xF +// if x < 10 { +// return x + 48 +// } else { +// return x - 10 + 65 +// } +// } +// output := make([]byte, len(input)*2) +// +// for i := 0; i < len(input); i++ { +// output[i*2] = byteToHex((input[i] & 0xF0) >> 4) +// output[i*2+1] = byteToHex(input[i] & 0xF) +// } +// return output +//} +func (obj *AuthObject) generatePasswordEncKey() ([]byte, error) { hash := md5.New() - switch verifierType { - case 2361: - buffer := make([]byte, 16) - for x := 0; x < 16; x++ { - buffer[x] = key1[x] ^ key2[x] - } - - _, err := hash.Write(buffer) - if err != nil { - return nil, err - } - return hash.Sum(nil), nil - case 6949: - buffer := make([]byte, 24) - for x := 0; x < 24; x++ { - buffer[x] = key1[x] ^ key2[x] + key1 := obj.ServerSessKey + key2 := obj.ClientSessKey + start := 16 + logonCompatibility := obj.tcpNego.ServerCompileTimeCaps[4] + if logonCompatibility&32 != 0 { + var keyBuffer string + switch obj.VerifierType { + case 2361: + buffer := append(key2[:len(key2)/2], key1[:len(key1)/2]...) + keyBuffer = fmt.Sprintf("%X", buffer) + case 6949: + buffer := append(key2[:24], key1[:24]...) + keyBuffer = fmt.Sprintf("%X", buffer) + case 18453: + buffer := append(key2, key1...) + keyBuffer = fmt.Sprintf("%X", buffer) + default: + return nil, errors.New("unsupported verifier type") } - _, err := hash.Write(buffer[:16]) + df2key, err := HexStringToBytes(obj.pbkdf2ChkSalt) if err != nil { return nil, err } - ret := hash.Sum(nil) - hash.Reset() - _, err = hash.Write(buffer[16:]) - if err != nil { - return nil, err + return generateSpeedyKey(df2key, []byte(keyBuffer), obj.pbkdf2SderCount)[:32], nil + } else { + switch obj.VerifierType { + case 2361: + buffer := make([]byte, 16) + for x := 0; x < 16; x++ { + buffer[x] = key1[x+start] ^ key2[x+start] + } + _, err := hash.Write(buffer) + if err != nil { + return nil, err + } + return hash.Sum(nil), nil + case 6949: + buffer := make([]byte, 24) + for x := 0; x < 24; x++ { + buffer[x] = key1[x+start] ^ key2[x+start] + } + _, err := hash.Write(buffer[:16]) + if err != nil { + return nil, err + } + ret := hash.Sum(nil) + hash.Reset() + _, err = hash.Write(buffer[16:]) + if err != nil { + return nil, err + } + ret = append(ret, hash.Sum(nil)...) + return ret[:24], nil + default: + return nil, errors.New("unsupported verifier type") } - ret = append(ret, hash.Sum(nil)...) - return ret[:24], nil + } - return nil, nil } func (obj *AuthObject) VerifyResponse(response string) bool { @@ -409,3 +559,53 @@ func (obj *AuthObject) VerifyResponse(response string) bool { //(byte) 95, (byte) 67, (byte) 76, (byte) 73, (byte) 69, (byte) 78, (byte) 84 }; } + +func (obj *AuthObject) TestResponse(password, pbkdf2ChkSalt string, vGenCount, sDerCount int) error { + padding := false + obj.pbkdf2ChkSalt = pbkdf2ChkSalt + obj.pbkdf2VgenCount = vGenCount + obj.pbkdf2SderCount = sDerCount + obj.tcpNego = &TCPNego{ + MessageCode: 0, + ProtocolServerVersion: 0, + ProtocolServerString: "", + OracleVersion: 0, + ServerCharset: 0, + ServerFlags: 0, + CharsetElem: 0, + ServernCharset: 0, + ServerCompileTimeCaps: []byte{0, 0, 0, 0, 32}, + ServerRuntimeCaps: nil, + } + salt, err := HexStringToBytes(obj.Salt) + if err != nil { + return err + } + message := append(salt, []byte("AUTH_PBKDF2_SPEEDY_KEY")...) + speedyKey := generateSpeedyKey(message, []byte(password), obj.pbkdf2VgenCount) + + buffer := append(speedyKey, salt...) + hash := sha512.New() + hash.Write(buffer) + key := hash.Sum(nil)[:32] + obj.ServerSessKey, err = decryptSessionKey(padding, key, obj.EServerSessKey) + if err != nil { + return err + } + obj.ClientSessKey, err = decryptSessionKey(padding, key, obj.EClientSessKey) + if err != nil { + return err + } + newKey, err := obj.generatePasswordEncKey() + if err != nil { + return err + } + fmt.Println(decryptSessionKey(padding, newKey, obj.EPassword)) + + obj.EPassword, err = EncryptPassword([]byte(password), newKey, false) + if err != nil { + return err + } + obj.ESpeedyKey, err = EncryptPassword(speedyKey, newKey, false) + return err +} diff --git a/command.go b/command.go index ea57dc46..736d0fdf 100644 --- a/command.go +++ b/command.go @@ -52,7 +52,7 @@ type defaultStmt struct { queryID uint64 Pars []ParameterInfo columns []ParameterInfo - scnFromExe []int + scnForSnapshot []int arrayBindCount int } @@ -95,14 +95,26 @@ func (stmt *defaultStmt) basicWrite(exeOp int, parse, define bool) error { session.PutUint(0, 4, true, true) session.PutUint(0, 4, true, true) } - // add fetch size = max(int32) + //switch (longFetchSize) + //{ + //case -1: + // this.m_marshallingEngine.MarshalUB4((long) int.MaxValue); + // break; + //case 0: + // this.m_marshallingEngine.MarshalUB4(1L); + // break; + //default: + // this.m_marshallingEngine.MarshalUB4((long) longFetchSize); + // break; + //} + // we use here int.MaxValue session.PutUint(0x7FFFFFFF, 4, true, true) + //session.PutInt(1, 4, true, true) if len(stmt.Pars) > 0 { session.PutBytes(1) session.PutUint(len(stmt.Pars), 2, true, true) } else { session.PutBytes(0, 0) - } session.PutBytes(0, 0, 0, 0, 0) if define { @@ -117,8 +129,23 @@ func (stmt *defaultStmt) basicWrite(exeOp int, parse, define bool) error { if session.TTCVersion >= 5 { session.PutBytes(0, 0, 0, 0, 0) } + if session.TTCVersion >= 7 { + if stmt.stmtType == DML && stmt.arrayBindCount > 0 { + session.PutBytes(1) + session.PutInt(stmt.arrayBindCount, 4, true, true) + session.PutBytes(1) + } else { + session.PutBytes(0, 0, 0) + } + } + if session.TTCVersion >= 8 { + session.PutBytes(0, 0, 0, 0, 0) + } + if session.TTCVersion >= 9 { + session.PutBytes(0, 0) + } if parse { - session.PutBytes(stmt.connection.strConv.Encode(stmt.text)...) + session.PutString(string(stmt.connection.strConv.Encode(stmt.text))) } if define { session.PutBytes(0) @@ -143,19 +170,23 @@ func (stmt *defaultStmt) basicWrite(exeOp int, parse, define bool) error { case DML: fallthrough case PLSQL: - if stmt.arrayBindCount <= 1 { - al8i4[1] = 1 - } else { + if stmt.arrayBindCount > 0 { al8i4[1] = stmt.arrayBindCount + if stmt.stmtType == DML { + al8i4[9] = 0x4000 + } + } else { + al8i4[1] = 1 } case OTHERS: al8i4[1] = 1 default: + //this.m_al8i4[1] = !fetch ? 0L : noOfRowsToFetch; al8i4[1] = stmt._noOfRowsToFetch } - if len(stmt.scnFromExe) == 2 { - al8i4[5] = stmt.scnFromExe[0] - al8i4[6] = stmt.scnFromExe[1] + if len(stmt.scnForSnapshot) == 2 { + al8i4[5] = stmt.scnForSnapshot[0] + al8i4[6] = stmt.scnForSnapshot[1] } else { al8i4[5] = 0 al8i4[6] = 0 @@ -165,6 +196,11 @@ func (stmt *defaultStmt) basicWrite(exeOp int, parse, define bool) error { } else { al8i4[7] = 0 } + if exeOp&32 != 0 { + al8i4[9] |= 0x8000 + } else { + al8i4[9] &= -0x8000 + } for x := 0; x < len(al8i4); x++ { session.PutUint(al8i4[x], 4, true, true) } @@ -215,15 +251,15 @@ func NewStmt(text string, conn *Connection) *Stmt { //parse: true, //execute: true, //define: false, - //scnFromExe: make([]int, 2), + //scnForSnapshot: make([]int, 2), } ret.connection = conn ret.text = text ret._hasBLOB = false ret._hasLONG = false ret.disableCompression = true - ret.arrayBindCount = 1 - ret.scnFromExe = make([]int, 2) + ret.arrayBindCount = 0 + ret.scnForSnapshot = make([]int, 2) // get stmt type uCmdText := strings.TrimSpace(strings.ToUpper(text)) if strings.HasPrefix(uCmdText, "SELECT") || strings.HasPrefix(uCmdText, "WITH") { @@ -248,10 +284,13 @@ func NewStmt(text string, conn *Connection) *Stmt { } func (stmt *Stmt) write(session *network.Session) error { - if !stmt.parse && stmt.stmtType == DML && !stmt.reSendParDef { + if !stmt.parse && !stmt.reSendParDef { exeOf := 0 execFlag := 0 - count := stmt.arrayBindCount + count := 1 + if stmt.arrayBindCount > 0 { + count = stmt.arrayBindCount + } if stmt.stmtType == SELECT { session.PutBytes(3, 0x4E, 0) count = stmt._noOfRowsToFetch @@ -355,23 +394,24 @@ func (stmt *Stmt) write(session *network.Session) error { //} if len(stmt.Pars) > 0 { - for x := 0; x < stmt.arrayBindCount; x++ { - session.PutBytes(7) - for _, par := range stmt.Pars { - if par.DataType != RAW { - if par.DataType == REFCURSOR { - session.PutBytes(1, 0) - } else { - session.PutClr(par.BValue) - } - } - } - for _, par := range stmt.Pars { - if par.DataType == RAW { + session.PutBytes(7) + for _, par := range stmt.Pars { + if par.DataType != RAW { + if par.DataType == REFCURSOR { + session.PutBytes(1, 0) + } else { session.PutClr(par.BValue) } } } + for _, par := range stmt.Pars { + if par.DataType == RAW { + session.PutClr(par.BValue) + } + } + //for x := 0; x < stmt.arrayBindCount; x++ { + // + //} //session.PutUint(7, 1, false, false) //for _, par := range stmt.Pars { // session.PutClr(par.BValue) @@ -670,11 +710,11 @@ func (stmt *defaultStmt) read(dataSet *DataSet) error { sourceLocator: data, } session.SaveState() - dataSize, err := lob.getSize(session) + dataSize, err := lob.getSize(stmt.connection) if err != nil { return err } - lobData, err := lob.getData(session) + lobData, err := lob.getData(stmt.connection) if err != nil { return err } @@ -729,7 +769,7 @@ func (stmt *defaultStmt) read(dataSet *DataSet) error { return err } for x := 0; x < 2; x++ { - stmt.scnFromExe[x], err = session.GetInt(4, true, true) + stmt.scnForSnapshot[x], err = session.GetInt(4, true, true) if err != nil { return err } @@ -741,6 +781,15 @@ func (stmt *defaultStmt) read(dataSet *DataSet) error { } } _, err = session.GetInt(2, true, true) + if err != nil { + return err + } + //if num > 0 { + // _, err = session.GetBytes(num) + // if err != nil { + // return err + // } + //} //fmt.Println(num) //if (num > 0) // this.m_marshallingEngine.UnmarshalNBytes_ScanOnly(num); @@ -774,7 +823,20 @@ func (stmt *defaultStmt) read(dataSet *DataSet) error { } } } - + if session.TTCVersion >= 7 && stmt.stmtType == DML && stmt.arrayBindCount > 0 { + length, err := session.GetInt(4, true, true) + if err != nil { + return err + } + //for (int index = 0; index < length3; ++index) + // rowsAffectedByArrayBind[index] = this.m_marshallingEngine.UnmarshalSB8(); + for i := 0; i < length; i++ { + _, err = session.GetInt(8, true, true) + if err != nil { + return err + } + } + } case 11: err = dataSet.load(session) if err != nil { @@ -930,7 +992,7 @@ func (stmt *Stmt) NewParam(name string, val driver.Value, size int, direction Pa Name: name, Direction: direction, Flag: 3, - CharsetID: 871, + CharsetID: stmt.connection.tcpNego.ServernCharset, CharsetForm: 1, } if val == nil { @@ -972,17 +1034,20 @@ func (stmt *Stmt) NewParam(name string, val driver.Value, size int, direction Pa case string: param.DataType = NCHAR param.ContFlag = 16 - param.MaxCharLen = len(val) - param.CharsetForm = 1 + param.MaxCharLen = len([]rune(val)) + param.CharsetForm = 2 if val == "" && direction == Input { param.BValue = nil param.MaxLen = 1 } else { + tempCharset := stmt.connection.strConv.LangID + stmt.connection.strConv.LangID = param.CharsetID param.BValue = stmt.connection.strConv.Encode(val) + stmt.connection.strConv.LangID = tempCharset if size > len(val) { param.MaxCharLen = size } - param.MaxLen = param.MaxCharLen * converters.MaxBytePerChar(stmt.connection.strConv.LangID) + param.MaxLen = param.MaxCharLen * converters.MaxBytePerChar(param.CharsetID) } case []byte: param.BValue = val @@ -1008,7 +1073,7 @@ func (stmt *Stmt) AddParam(name string, val driver.Value, size int, direction Pa stmt.Pars = append(stmt.Pars, *stmt.NewParam(name, val, size, direction)) } -func (stmt *Stmt) AddRefCursorParam(_ string) { +func (stmt *Stmt) AddRefCursorParam(name string) { par := stmt.NewParam("1", nil, 0, Output) par.DataType = REFCURSOR par.ContFlag = 0 diff --git a/connection.go b/connection.go index 8426d5d2..cfd3523d 100644 --- a/connection.go +++ b/connection.go @@ -11,6 +11,7 @@ import ( "strconv" "strings" + "github.com/sijms/go-ora/advanced_nego" "github.com/sijms/go-ora/converters" "github.com/sijms/go-ora/network" "github.com/sijms/go-ora/trace" @@ -37,17 +38,25 @@ const ( type NLSData struct { Calender string Comp string + Language string LengthSemantics string NCharConvExcep string + NCharConvImp string DateLang string Sort string Currency string DateFormat string + TimeFormat string IsoCurrency string NumericChars string DualCurrency string + UnionCurrency string Timestamp string TimestampTZ string + TTimezoneFormat string + NTimezoneFormat string + Territory string + Charset string } type Connection struct { State ConnectionState @@ -63,6 +72,7 @@ type Connection struct { dBVersion *DBVersion sessionID int serialID int + transactionID []byte strConv *converters.StringConverter NLSData NLSData } @@ -269,12 +279,34 @@ func (conn *Connection) Open() error { default: conn.LogonMode = 0 } + conn.session = network.NewSession(*conn.connOption) - err := conn.session.Connect() + session := conn.session + err := session.Connect() if err != nil { return err } + // advanced negotiation + if session.Context.ACFL0&1 != 0 && session.Context.ACFL0&4 == 0 && session.Context.ACFL1&8 == 0 { + ano, err := advanced_nego.NewAdvNego(conn.connOption) + if err != nil { + return err + } + err = ano.Write(session) + if err != nil { + return err + } + err = ano.Read(session) + if err != nil { + return err + } + //fmt.Printf("%#v\n", ano) + //return errors.New("stop connection") + //start advanced negotionation + //session.Context.Ano + } else { + } conn.tcpNego, err = NewTCPNego(conn.session) if err != nil { return err @@ -282,22 +314,24 @@ func (conn *Connection) Open() error { // create string converter object conn.strConv = converters.NewStringConverter(conn.tcpNego.ServerCharset) conn.session.StrConv = conn.strConv - - conn.dataNego, err = buildTypeNego(conn.tcpNego, conn.session) + conn.tcpNego.ServerFlags |= 2 + conn.dataNego = buildTypeNego(conn.tcpNego, conn.session) + err = conn.dataNego.write(conn.session) + if err != nil { + return err + } + err = conn.dataNego.read(conn.session) if err != nil { return err } - conn.session.TTCVersion = conn.dataNego.CompileTimeCaps[7] - if conn.tcpNego.ServerCompileTimeCaps[7] < conn.session.TTCVersion { conn.session.TTCVersion = conn.tcpNego.ServerCompileTimeCaps[7] } - //if (((int) this.m_serverCompiletimeCapabilities[15] & 1) != 0) - // this.m_marshallingEngine.HasEOCSCapability = true; - //if (((int) this.m_serverCompiletimeCapabilities[16] & 16) != 0) - // this.m_marshallingEngine.HasFSAPCapability = true; + //this.m_b32kTypeSupported = this.m_dtyNeg.m_b32kTypeSupported; + //this.m_bSupportSessionStateOps = this.m_dtyNeg.m_bSupportSessionStateOps; + //this.m_marshallingEngine.m_bServerUsingBigSCN = this.m_serverCompiletimeCapabilities[7] >= (byte) 8; err = conn.doAuth() if err != nil { @@ -309,30 +343,25 @@ func (conn *Connection) Open() error { return err } - if len(conn.SessionProperties) == 0 { - //return errors.New(fmt.Sprint("Session properties is null")) - } else { - sessionID, err := strconv.ParseUint(conn.SessionProperties["AUTH_SESSION_ID"], 10, 32) - if err != nil { - return err - } - conn.sessionID = int(sessionID) - serialNum, err := strconv.ParseUint(conn.SessionProperties["AUTH_SERIAL_NUM"], 10, 32) - if err != nil { - return err - } - conn.serialID = int(serialNum) - conn.connOption.InstanceName = conn.SessionProperties["AUTH_SC_INSTANCE_NAME"] - conn.connOption.Host = conn.SessionProperties["AUTH_SC_SERVER_HOST"] - conn.connOption.ServiceName = conn.SessionProperties["AUTH_SC_SERVICE_NAME"] - conn.connOption.DomainName = conn.SessionProperties["AUTH_SC_DB_DOMAIN"] - conn.connOption.DBName = conn.SessionProperties["AUTH_SC_DBUNIQUE_NAME"] + sessionID, err := strconv.ParseUint(conn.SessionProperties["AUTH_SESSION_ID"], 10, 32) + if err != nil { + return err } - - _, err = conn.GetNLS() + conn.sessionID = int(sessionID) + serialNum, err := strconv.ParseUint(conn.SessionProperties["AUTH_SERIAL_NUM"], 10, 32) if err != nil { return err } + conn.serialID = int(serialNum) + conn.connOption.InstanceName = conn.SessionProperties["AUTH_SC_INSTANCE_NAME"] + conn.connOption.Host = conn.SessionProperties["AUTH_SC_SERVER_HOST"] + conn.connOption.ServiceName = conn.SessionProperties["AUTH_SC_SERVICE_NAME"] + conn.connOption.DomainName = conn.SessionProperties["AUTH_SC_DB_DOMAIN"] + conn.connOption.DBName = conn.SessionProperties["AUTH_SC_DBUNIQUE_NAME"] + //_, err = conn.GetNLS() + //if err != nil { + // return err + //} return nil } @@ -427,7 +456,8 @@ func (conn *Connection) doAuth() error { conn.LogonMode = conn.LogonMode | NoNewPass conn.session.PutUint(int(conn.LogonMode), 4, true, true) conn.session.PutBytes(1, 1, 5, 1, 1) - conn.session.PutBytes([]byte(conn.conStr.UserID)...) + conn.session.PutString(conn.conStr.UserID) + //conn.session.PutBytes([]byte()...) conn.session.PutKeyValString("AUTH_TERMINAL", conn.connOption.ClientData.HostName, 0) conn.session.PutKeyValString("AUTH_PROGRAM_NM", conn.connOption.ClientData.ProgramName, 0) conn.session.PutKeyValString("AUTH_MACHINE", conn.connOption.ClientData.HostName, 0) @@ -449,7 +479,7 @@ func (conn *Connection) doAuth() error { } stop := false for !stop { - msg, err := conn.session.GetInt(1, false, false) + msg, err := conn.session.GetByte() if err != nil { return err } @@ -464,7 +494,7 @@ func (conn *Connection) doAuth() error { } stop = true case 8: - dictLen, err := conn.session.GetInt(4, true, true) + dictLen, err := conn.session.GetInt(2, true, true) if err != nil { return err } @@ -484,7 +514,22 @@ func (conn *Connection) doAuth() error { if warning != nil { fmt.Println(warning) } - stop = true + case 23: + opCode, err := conn.session.GetByte() + if err != nil { + return err + } + if opCode == 5 { + err = conn.loadNLSData() + if err != nil { + return err + } + } else { + err = conn.getServerNetworkInformation(opCode) + if err != nil { + return err + } + } default: return errors.New(fmt.Sprintf("message code error: received code %d and expected code is 8", msg)) } @@ -494,3 +539,228 @@ func (conn *Connection) doAuth() error { // conn.authObject.VerifyResponse(conn.SessionProperties["AUTH_SVR_RESPONSE"]) return nil } +func (conn *Connection) loadNLSData() error { + _, err := conn.session.GetInt(2, true, true) + if err != nil { + return err + } + _, err = conn.session.GetByte() + if err != nil { + return err + } + length, err := conn.session.GetInt(4, true, true) + if err != nil { + return err + } + _, err = conn.session.GetByte() + if err != nil { + return err + } + for i := 0; i < length; i++ { + nlsKey, nlsVal, nlsCode, err := conn.session.GetKeyVal() + if err != nil { + return err + } + conn.NLSData.SaveNLSValue(string(nlsKey), string(nlsVal), nlsCode) + } + _, err = conn.session.GetInt(4, true, true) + return err +} + +func (conn *Connection) getServerNetworkInformation(code uint8) error { + session := conn.session + if code == 0 { + _, err := session.GetByte() + return err + } + switch code - 1 { + case 1: + // receive OCOSPID + length, err := session.GetInt(2, true, true) + if err != nil { + return err + } + _, err = session.GetByte() + if err != nil { + return err + } + _, err = session.GetBytes(length) + if err != nil { + return err + } + case 3: + // receive OCSESSRET session return values + _, err := session.GetInt(2, true, true) + if err != nil { + return err + } + _, err = session.GetByte() + if err != nil { + return err + } + length, err := session.GetInt(2, true, true) + if err != nil { + return err + } + // get nls data + for i := 0; i < length; i++ { + nlsKey, nlsVal, nlsCode, err := session.GetKeyVal() + if err != nil { + return err + } + conn.NLSData.SaveNLSValue(string(nlsKey), string(nlsVal), nlsCode) + } + flag, err := session.GetInt(4, true, true) + if err != nil { + return err + } + sessionID, err := session.GetInt(4, true, true) + if err != nil { + return err + } + serialID, err := session.GetInt(2, true, true) + if err != nil { + return err + } + if flag&4 == 4 { + conn.sessionID = sessionID + conn.serialID = serialID + // save session id and serial number to connection + } + case 4: + err := conn.loadNLSData() + if err != nil { + return err + } + case 6: + length, err := session.GetInt(4, true, true) + if err != nil { + return err + } + conn.transactionID, err = session.GetClr() + if len(conn.transactionID) > length { + conn.transactionID = conn.transactionID[:length] + } + case 7: + _, err := session.GetInt(2, true, true) + if err != nil { + return err + } + _, err = session.GetByte() + if err != nil { + return err + } + _, err = session.GetInt(4, true, true) + if err != nil { + return err + } + _, err = session.GetInt(4, true, true) + if err != nil { + return err + } + _, err = session.GetByte() + if err != nil { + return err + } + _, err = session.GetDlc() + if err != nil { + return err + } + case 8: + _, err := session.GetInt(2, true, true) + if err != nil { + return err + } + _, err = session.GetByte() + if err != nil { + return err + } + } + return nil +} + +func (nls *NLSData) SaveNLSValue(key, value string, code int) { + key = strings.ToUpper(key) + if len(key) > 0 { + switch key { + case "AUTH_NLS_LXCCURRENCY": + code = 0 + case "AUTH_NLS_LXCISOCURR": + code = 1 + case "AUTH_NLS_LXCNUMERICS": + code = 2 + case "AUTH_NLS_LXCDATEFM": + code = 7 + case "AUTH_NLS_LXCDATELANG": + code = 8 + case "AUTH_NLS_LXCTERRITORY": + code = 9 + case "SESSION_NLS_LXCCHARSET": + code = 10 + case "AUTH_NLS_LXCSORT": + code = 11 + case "AUTH_NLS_LXCCALENDAR": + code = 12 + case "AUTH_NLS_LXLAN": + code = 16 + case "AL8KW_NLSCOMP": + code = 50 + case "AUTH_NLS_LXCUNIONCUR": + code = 52 + case "AUTH_NLS_LXCTIMEFM": + code = 57 + case "AUTH_NLS_LXCSTMPFM": + code = 58 + case "AUTH_NLS_LXCTTZNFM": + code = 59 + case "AUTH_NLS_LXCSTZNFM": + code = 60 + case "SESSION_NLS_LXCNLSLENSEM": + code = 61 + case "SESSION_NLS_LXCNCHAREXCP": + code = 62 + case "SESSION_NLS_LXCNCHARIMP": + code = 63 + } + } + switch code { + case 0: + nls.Currency = value + case 1: + nls.IsoCurrency = value + case 2: + nls.NumericChars = value + case 7: + nls.DateFormat = value + case 8: + nls.DateLang = value + case 9: + nls.Territory = value + case 10: + nls.Charset = value + case 11: + nls.Sort = value + case 12: + nls.Calender = value + case 16: + nls.Language = value + case 50: + nls.Comp = value + case 52: + nls.UnionCurrency = value + case 57: + nls.TimeFormat = value + case 58: + nls.Timestamp = value + case 59: + nls.TTimezoneFormat = value + case 60: + nls.NTimezoneFormat = value + case 61: + nls.LengthSemantics = value + case 62: + nls.NCharConvExcep = value + case 63: + nls.NCharConvImp = value + } +} diff --git a/data_type_nego.go b/data_type_nego.go index a99f37f7..26f498e7 100644 --- a/data_type_nego.go +++ b/data_type_nego.go @@ -1,8 +1,6 @@ package go_ora import ( - "bytes" - "encoding/binary" "errors" "fmt" "github.com/sijms/go-ora/network" @@ -10,38 +8,21 @@ import ( ) type DataTypeNego struct { - MessageCode uint8 - Server *TCPNego - TypeAndRep []int16 - RuntimeTypeAndRep []int16 - DataTypeRepFor1100 int16 - CompileTimeCaps []byte - RuntimeCap []byte - DBTimeZone []byte + MessageCode uint8 + Server *TCPNego + TypeAndRep []int16 + RuntimeTypeAndRep []int16 + DataTypeRepFor1100 int16 + DataTypeRepFor1200 int16 + CompileTimeCaps []byte + RuntimeCap []byte + DBTimeZone []byte + b32kTypeSupported bool + supportSessionStateOps bool } const bufferGrow int = 2369 -//internal TTCDataTypeNegotiation(MarshallingEngine marshallingEngine, -// byte[] serverCompileTimeCap, byte[] serverRunTimeCap, -// short networkCharSetId, -// short networkNCharSetId, -// byte networkFlags) -//: base(marshallingEngine, (byte) 2) -//{ -//this.m_clientRemoteIn = networkCharSetId; -//this.m_clientRemoteOut = networkCharSetId; -//this.m_ncharSetId = networkNCharSetId; -//this.m_clientFlags = networkFlags; -// -// -// -// -//internal override void ReInit(MarshallingEngine marshallingEngine) -//{ -//base.ReInit(marshallingEngine); -//this.m_dbTimeZoneBytes = (byte[]) null; -//} func (n *DataTypeNego) addTypeRep(dty int16, ndty int16, rep int16) { if n.TypeAndRep == nil { n.TypeAndRep = make([]int16, bufferGrow) @@ -61,22 +42,58 @@ func (n *DataTypeNego) addTypeRep(dty int16, ndty int16, rep int16) { } } -func buildTypeNego(nego *TCPNego, session *network.Session) (*DataTypeNego, error) { +func buildTypeNego(nego *TCPNego, session *network.Session) *DataTypeNego { result := DataTypeNego{ MessageCode: 2, Server: nego, TypeAndRep: make([]int16, bufferGrow), CompileTimeCaps: []byte{ - 6, 1, 0, 0, 10, 1, 1, 6, - 1, 1, 1, 1, 1, 1, 0, 0x29, - 0x90, 3, 7, 3, 0, 1, 0, 0x6B, - 1, 0, 5, 1, 0, 0, 0, 0, - 0, 0, 0, 0, 1, 2}, - RuntimeCap: []byte{2, 1, 0, 0, 0, 0, 0}, + 6, 1, 0, 0, 106, 1, 1, 11, + 1, 1, 1, 1, 1, 1, 0, 41, + 144, 3, 7, 3, 0, 1, 0, 235, + 1, 0, 5, 1, 0, 0, 0, 24, + 0, 0, 7, 32, 2, 58, 0, 0, 5, + //6, 1, 0, 0, 10, 1, 1, 6, + //1, 1, 1, 1, 1, 1, 0, 0x29, + //0x90, 3, 7, 3, 0, 1, 0, 0x6B, + //1, 0, 5, 1, 0, 0, 0, 0, + //0, 0, 0, 0, 1, 2 + }, + RuntimeCap: []byte{2, 1, 0, 0, 0, 0, 0}, + b32kTypeSupported: false, + supportSessionStateOps: false, + } + if len(result.Server.ServerCompileTimeCaps) <= 27 || result.Server.ServerCompileTimeCaps[27] == 0 { + result.CompileTimeCaps[27] = 0 + } + xmlTypeClientSideDecoding := false + if len(result.Server.ServerCompileTimeCaps) > 7 { + if result.Server.ServerCompileTimeCaps[7] >= 8 && xmlTypeClientSideDecoding { + result.CompileTimeCaps[36] = 4 + } else if result.Server.ServerCompileTimeCaps[7] < 7 { + result.CompileTimeCaps[36] = 0 + } + } + //this.m_clientRemoteIn = serverCharacterSet; + //this.m_clientRemoteOut = serverCharacterSet; + //this.m_ncharSetId = serverNCharSet; + //this.m_clientFlags = serverFlags; + if len(result.Server.ServerRuntimeCaps) < 1 || result.Server.ServerRuntimeCaps[1]&1 != 1 { + result.RuntimeCap[1] &= 0 + } + if len(result.Server.ServerRuntimeCaps) > 6 { + if result.Server.ServerRuntimeCaps[6]&4 == 4 { + result.RuntimeCap[6] |= 4 + result.b32kTypeSupported = true + } + if result.Server.ServerRuntimeCaps[6]&16 == 16 { + result.supportSessionStateOps = true + } + if result.Server.ServerRuntimeCaps[6]&2 == 2 { + result.RuntimeCap[6] |= 2 + } } - if result.Server.ServerCompileTimeCaps == nil || - len(result.Server.ServerCompileTimeCaps) <= 37 || - result.Server.ServerCompileTimeCaps[37]&2 != 2 { + if len(result.Server.ServerCompileTimeCaps) <= 37 || result.Server.ServerCompileTimeCaps[37]&2 != 2 { result.CompileTimeCaps[37] = 0 result.CompileTimeCaps[1] = 0 } @@ -287,6 +304,9 @@ func buildTypeNego(nego *TCPNego, session *network.Session) (*DataTypeNego, erro result.addTypeRep(575, 575, 1) result.addTypeRep(576, 576, 1) result.addTypeRep(578, 578, 1) + result.addTypeRep(563, 563, 1) + result.addTypeRep(564, 564, 1) + result.addTypeRep(579, 579, 1) result.addTypeRep(580, 580, 1) result.addTypeRep(581, 581, 1) result.addTypeRep(582, 582, 1) @@ -373,51 +393,396 @@ func buildTypeNego(nego *TCPNego, session *network.Session) (*DataTypeNego, erro result.addTypeRep(231, 231, 1) result.addTypeRep(232, 231, 1) result.addTypeRep(233, 233, 1) + result.addTypeRep(252, 252, 1) result.addTypeRep(241, 109, 1) result.addTypeRep(515, 0, 0) + + //result.addTypeRep(1, 1, 1) + //result.addTypeRep(2, 2, 10) + //result.addTypeRep(8, 8, 1) + //result.addTypeRep(12, 12, 10) + //result.addTypeRep(23, 23, 1) + //result.addTypeRep(24, 24, 1) + //result.addTypeRep(25, 25, 1) + //result.addTypeRep(26, 26, 1) + //result.addTypeRep(27, 27, 1) + //result.addTypeRep(28, 28, 1) + //result.addTypeRep(29, 29, 1) + //result.addTypeRep(30, 30, 1) + //result.addTypeRep(31, 31, 1) + //result.addTypeRep(32, 32, 1) + //result.addTypeRep(33, 33, 1) + //result.addTypeRep(10, 10, 1) + //result.addTypeRep(11, 11, 1) + //result.addTypeRep(40, 40, 1) + //result.addTypeRep(41, 41, 1) + //result.addTypeRep(117, 117, 1) + //result.addTypeRep(120, 120, 1) + //result.addTypeRep(290, 290, 1) + //result.addTypeRep(291, 291, 1) + //result.addTypeRep(292, 292, 1) + //result.addTypeRep(293, 293, 1) + //result.addTypeRep(294, 294, 1) + //result.addTypeRep(298, 298, 1) + //result.addTypeRep(299, 299, 1) + //result.addTypeRep(300, 300, 1) + //result.addTypeRep(301, 301, 1) + //result.addTypeRep(302, 302, 1) + //result.addTypeRep(303, 303, 1) + //result.addTypeRep(304, 304, 1) + //result.addTypeRep(305, 305, 1) + //result.addTypeRep(306, 306, 1) + //result.addTypeRep(307, 307, 1) + //result.addTypeRep(308, 308, 1) + //result.addTypeRep(309, 309, 1) + //result.addTypeRep(310, 310, 1) + //result.addTypeRep(311, 311, 1) + //result.addTypeRep(312, 312, 1) + //result.addTypeRep(313, 313, 1) + //result.addTypeRep(315, 315, 1) + //result.addTypeRep(316, 316, 1) + //result.addTypeRep(317, 317, 1) + //result.addTypeRep(318, 318, 1) + //result.addTypeRep(319, 319, 1) + //result.addTypeRep(320, 320, 1) + //result.addTypeRep(321, 321, 1) + //result.addTypeRep(322, 322, 1) + //result.addTypeRep(323, 323, 1) + //result.addTypeRep(327, 327, 1) + //result.addTypeRep(328, 328, 1) + //result.addTypeRep(329, 329, 1) + //result.addTypeRep(331, 331, 1) + //result.addTypeRep(333, 333, 1) + //result.addTypeRep(334, 334, 1) + //result.addTypeRep(335, 335, 1) + //result.addTypeRep(336, 336, 1) + //result.addTypeRep(337, 337, 1) + //result.addTypeRep(338, 338, 1) + //result.addTypeRep(339, 339, 1) + //result.addTypeRep(340, 340, 1) + //result.addTypeRep(341, 341, 1) + //result.addTypeRep(342, 342, 1) + //result.addTypeRep(343, 343, 1) + //result.addTypeRep(344, 344, 1) + //result.addTypeRep(345, 345, 1) + //result.addTypeRep(346, 346, 1) + //result.addTypeRep(348, 348, 1) + //result.addTypeRep(349, 349, 1) + //result.addTypeRep(354, 354, 1) + //result.addTypeRep(355, 355, 1) + //result.addTypeRep(359, 359, 1) + //result.addTypeRep(363, 363, 1) + //result.addTypeRep(380, 380, 1) + //result.addTypeRep(381, 381, 1) + //result.addTypeRep(382, 382, 1) + //result.addTypeRep(383, 383, 1) + //result.addTypeRep(384, 384, 1) + //result.addTypeRep(385, 385, 1) + //result.addTypeRep(386, 386, 1) + //result.addTypeRep(387, 387, 1) + //result.addTypeRep(388, 388, 1) + //result.addTypeRep(389, 389, 1) + //result.addTypeRep(390, 390, 1) + //result.addTypeRep(391, 391, 1) + //result.addTypeRep(393, 393, 1) + //result.addTypeRep(394, 394, 1) + //result.addTypeRep(395, 395, 1) + //result.addTypeRep(396, 396, 1) + //result.addTypeRep(397, 397, 1) + //result.addTypeRep(398, 398, 1) + //result.addTypeRep(399, 399, 1) + //result.addTypeRep(400, 400, 1) + //result.addTypeRep(401, 401, 1) + //result.addTypeRep(404, 404, 1) + //result.addTypeRep(405, 405, 1) + //result.addTypeRep(406, 406, 1) + //result.addTypeRep(407, 407, 1) + //result.addTypeRep(413, 413, 1) + //result.addTypeRep(414, 414, 1) + //result.addTypeRep(415, 415, 1) + //result.addTypeRep(416, 416, 1) + //result.addTypeRep(417, 417, 1) + //result.addTypeRep(418, 418, 1) + //result.addTypeRep(419, 419, 1) + //result.addTypeRep(420, 420, 1) + //result.addTypeRep(421, 421, 1) + //result.addTypeRep(422, 422, 1) + //result.addTypeRep(423, 423, 1) + //result.addTypeRep(424, 424, 1) + //result.addTypeRep(425, 425, 1) + //result.addTypeRep(426, 426, 1) + //result.addTypeRep(427, 427, 1) + //result.addTypeRep(429, 429, 1) + //result.addTypeRep(430, 430, 1) + //result.addTypeRep(431, 431, 1) + //result.addTypeRep(432, 432, 1) + //result.addTypeRep(433, 433, 1) + //result.addTypeRep(449, 449, 1) + //result.addTypeRep(450, 450, 1) + //result.addTypeRep(454, 454, 1) + //result.addTypeRep(455, 455, 1) + //result.addTypeRep(456, 456, 1) + //result.addTypeRep(457, 457, 1) + //result.addTypeRep(458, 458, 1) + //result.addTypeRep(459, 459, 1) + //result.addTypeRep(460, 460, 1) + //result.addTypeRep(461, 461, 1) + //result.addTypeRep(462, 462, 1) + //result.addTypeRep(463, 463, 1) + //result.addTypeRep(466, 466, 1) + //result.addTypeRep(467, 467, 1) + //result.addTypeRep(468, 468, 1) + //result.addTypeRep(469, 469, 1) + //result.addTypeRep(470, 470, 1) + //result.addTypeRep(471, 471, 1) + //result.addTypeRep(472, 472, 1) + //result.addTypeRep(473, 473, 1) + //result.addTypeRep(474, 474, 1) + //result.addTypeRep(475, 475, 1) + //result.addTypeRep(476, 476, 1) + //result.addTypeRep(477, 477, 1) + //result.addTypeRep(478, 478, 1) + //result.addTypeRep(479, 479, 1) + //result.addTypeRep(480, 480, 1) + //result.addTypeRep(481, 481, 1) + //result.addTypeRep(482, 482, 1) + //result.addTypeRep(483, 483, 1) + //result.addTypeRep(484, 484, 1) + //result.addTypeRep(485, 485, 1) + //result.addTypeRep(486, 486, 1) + //result.addTypeRep(490, 490, 1) + //result.addTypeRep(491, 491, 1) + //result.addTypeRep(492, 492, 1) + //result.addTypeRep(493, 493, 1) + //result.addTypeRep(494, 494, 1) + //result.addTypeRep(495, 495, 1) + //result.addTypeRep(496, 496, 1) + //result.addTypeRep(498, 498, 1) + //result.addTypeRep(499, 499, 1) + //result.addTypeRep(500, 500, 1) + //result.addTypeRep(501, 501, 1) + //result.addTypeRep(502, 502, 1) + //result.addTypeRep(509, 509, 1) + //result.addTypeRep(510, 510, 1) + //result.addTypeRep(513, 513, 1) + //result.addTypeRep(514, 514, 1) + //result.addTypeRep(516, 516, 1) + //result.addTypeRep(517, 517, 1) + //result.addTypeRep(518, 518, 1) + //result.addTypeRep(519, 519, 1) + //result.addTypeRep(520, 520, 1) + //result.addTypeRep(521, 521, 1) + //result.addTypeRep(522, 522, 1) + //result.addTypeRep(523, 523, 1) + //result.addTypeRep(524, 524, 1) + //result.addTypeRep(525, 525, 1) + //result.addTypeRep(526, 526, 1) + //result.addTypeRep(527, 527, 1) + //result.addTypeRep(528, 528, 1) + //result.addTypeRep(529, 529, 1) + //result.addTypeRep(530, 530, 1) + //result.addTypeRep(531, 531, 1) + //result.addTypeRep(532, 532, 1) + //result.addTypeRep(533, 533, 1) + //result.addTypeRep(534, 534, 1) + //result.addTypeRep(535, 535, 1) + //result.addTypeRep(536, 536, 1) + //result.addTypeRep(537, 537, 1) + //result.addTypeRep(538, 538, 1) + //result.addTypeRep(539, 539, 1) + //result.addTypeRep(540, 540, 1) + //result.addTypeRep(541, 541, 1) + //result.addTypeRep(542, 542, 1) + //result.addTypeRep(543, 543, 1) + //result.addTypeRep(560, 560, 1) + //result.addTypeRep(565, 565, 1) + //result.addTypeRep(572, 572, 1) + //result.addTypeRep(573, 573, 1) + //result.addTypeRep(574, 574, 1) + //result.addTypeRep(575, 575, 1) + //result.addTypeRep(576, 576, 1) + //result.addTypeRep(578, 578, 1) + //result.addTypeRep(580, 580, 1) + //result.addTypeRep(581, 581, 1) + //result.addTypeRep(582, 582, 1) + //result.addTypeRep(583, 583, 1) + //result.addTypeRep(584, 584, 1) + //result.addTypeRep(585, 585, 1) + //result.addTypeRep(3, 2, 10) + //result.addTypeRep(4, 2, 10) + //result.addTypeRep(5, 1, 1) + //result.addTypeRep(6, 2, 10) + //result.addTypeRep(7, 2, 10) + //result.addTypeRep(9, 1, 1) + //result.addTypeRep(13, 0, 0) + //result.addTypeRep(14, 0, 0) + //result.addTypeRep(15, 23, 1) + //result.addTypeRep(16, 0, 0) + //result.addTypeRep(17, 0, 0) + //result.addTypeRep(18, 0, 0) + //result.addTypeRep(19, 0, 0) + //result.addTypeRep(20, 0, 0) + //result.addTypeRep(21, 0, 0) + //result.addTypeRep(22, 0, 0) + //result.addTypeRep(39, 120, 1) + //result.addTypeRep(58, 0, 0) + //result.addTypeRep(68, 2, 10) + //result.addTypeRep(69, 0, 0) + //result.addTypeRep(70, 0, 0) + //result.addTypeRep(74, 0, 0) + //result.addTypeRep(76, 0, 0) + //result.addTypeRep(91, 2, 10) + //result.addTypeRep(94, 1, 1) + //result.addTypeRep(95, 23, 1) + //result.addTypeRep(96, 96, 1) + //result.addTypeRep(97, 96, 1) + //result.addTypeRep(100, 100, 1) + //result.addTypeRep(101, 101, 1) + //result.addTypeRep(102, 102, 1) + //result.addTypeRep(104, 11, 1) + //result.addTypeRep(105, 0, 0) + //result.addTypeRep(106, 106, 1) + //result.addTypeRep(108, 109, 1) + //result.addTypeRep(109, 109, 1) + //result.addTypeRep(110, 111, 1) + //result.addTypeRep(111, 111, 1) + //result.addTypeRep(112, 112, 1) + //result.addTypeRep(113, 113, 1) + //result.addTypeRep(114, 114, 1) + //result.addTypeRep(115, 115, 1) + //result.addTypeRep(116, 102, 1) + //result.addTypeRep(118, 0, 0) + //result.addTypeRep(119, 0, 0) + //result.addTypeRep(121, 0, 0) + //result.addTypeRep(122, 0, 0) + //result.addTypeRep(123, 0, 0) + //result.addTypeRep(136, 0, 0) + //result.addTypeRep(146, 146, 1) + //result.addTypeRep(147, 0, 0) + //result.addTypeRep(152, 2, 10) + //result.addTypeRep(153, 2, 10) + //result.addTypeRep(154, 2, 10) + //result.addTypeRep(155, 1, 1) + //result.addTypeRep(156, 12, 10) + //result.addTypeRep(172, 2, 10) + //result.addTypeRep(178, 178, 1) + //result.addTypeRep(179, 179, 1) + //result.addTypeRep(180, 180, 1) + //result.addTypeRep(181, 181, 1) + //result.addTypeRep(182, 182, 1) + //result.addTypeRep(183, 183, 1) + //result.addTypeRep(184, 12, 10) + //result.addTypeRep(185, 185, 1) + //result.addTypeRep(186, 186, 1) + //result.addTypeRep(187, 187, 1) + //result.addTypeRep(188, 188, 1) + //result.addTypeRep(189, 189, 1) + //result.addTypeRep(190, 190, 1) + //result.addTypeRep(191, 0, 0) + //result.addTypeRep(192, 0, 0) + //result.addTypeRep(195, 112, 1) + //result.addTypeRep(196, 113, 1) + //result.addTypeRep(197, 114, 1) + //result.addTypeRep(208, 208, 1) + //result.addTypeRep(209, 0, 0) + //result.addTypeRep(231, 231, 1) + //result.addTypeRep(232, 231, 1) + //result.addTypeRep(233, 233, 1) + //result.addTypeRep(241, 109, 1) + //result.addTypeRep(515, 0, 0) result.DataTypeRepFor1100 = result.TypeAndRep[0] result.addTypeRep(590, 590, 1) result.addTypeRep(591, 591, 1) result.addTypeRep(592, 592, 1) - if result.Server.ServerCompileTimeCaps != nil && len(result.Server.ServerCompileTimeCaps) > 7 && result.Server.ServerCompileTimeCaps[7] == 5 { - result.RuntimeTypeAndRep = result.TypeAndRep[:result.DataTypeRepFor1100] - } else { + result.addTypeRep(613, 613, 1) + result.addTypeRep(614, 614, 1) + result.addTypeRep(615, 615, 1) + result.addTypeRep(616, 616, 1) + result.addTypeRep(611, 611, 1) + result.addTypeRep(612, 612, 1) + result.addTypeRep(593, 593, 1) + result.addTypeRep(594, 594, 1) + result.addTypeRep(595, 595, 1) + result.addTypeRep(596, 596, 1) + result.addTypeRep(597, 597, 1) + result.addTypeRep(598, 598, 1) + result.addTypeRep(599, 599, 1) + result.addTypeRep(600, 600, 1) + result.addTypeRep(601, 601, 1) + result.addTypeRep(602, 602, 1) + result.addTypeRep(603, 603, 1) + result.addTypeRep(604, 604, 1) + result.addTypeRep(605, 605, 1) + result.addTypeRep(622, 622, 1) + result.addTypeRep(623, 623, 1) + result.addTypeRep(624, 624, 1) + result.addTypeRep(625, 625, 1) + result.addTypeRep(626, 626, 1) + result.addTypeRep(627, 627, 1) + result.addTypeRep(628, 628, 1) + result.addTypeRep(629, 629, 1) + result.addTypeRep(630, 630, 1) + result.addTypeRep(631, 631, 1) + result.addTypeRep(632, 632, 1) + result.addTypeRep(637, 637, 1) + result.addTypeRep(638, 638, 1) + result.addTypeRep(636, 636, 1) + + //result.addTypeRep(590, 590, 1) + //result.addTypeRep(591, 591, 1) + //result.addTypeRep(592, 592, 1) + result.DataTypeRepFor1200 = result.TypeAndRep[0] + result.addTypeRep(639, 639, 1) + result.addTypeRep(640, 640, 1) + + //if result.Server.ServerCompileTimeCaps == nil || len(result.Server.ServerCompileTimeCaps) < 8 { + // return nil, errors.New("server compile time caps length less than 8") + //} + if result.Server.ServerCompileTimeCaps[7] >= 8 { result.RuntimeTypeAndRep = result.TypeAndRep + } else if result.Server.ServerCompileTimeCaps[7] >= 7 { + result.RuntimeTypeAndRep = result.TypeAndRep[:result.DataTypeRepFor1200] + } else { + result.RuntimeTypeAndRep = result.TypeAndRep[:result.DataTypeRepFor1100] } - session.ResetBuffer() - session.PutBytes(result.bytes()...) - err := session.Write() - if err != nil { - return nil, err - } + //if result.Server.ServerCompileTimeCaps != nil && len(result.Server.ServerCompileTimeCaps) > 7 && + // result.Server.ServerCompileTimeCaps[7] == 5 { + // result.RuntimeTypeAndRep = result.TypeAndRep[:result.DataTypeRepFor1100] + //} else { + // result.RuntimeTypeAndRep = result.TypeAndRep + //} + + //getNum := func(session *network.Session, flag uint8) (int, error) { + // if flag == 0 { + // return session.GetInt(1, false, false) + // } else { + // return session.GetInt(2, false, true) + // } + //} + + return &result +} +func (nego *DataTypeNego) read(session *network.Session) error { msg, err := session.GetByte() if err != nil { - return nil, err + return err } if msg != 2 { - return nil, errors.New(fmt.Sprintf("message code error: received code %d and expected code is 2", msg)) + return errors.New(fmt.Sprintf("message code error: received code %d and expected code is 2", msg)) } - - if result.RuntimeCap[1] == 1 { - result.DBTimeZone, err = session.GetBytes(11) + if nego.RuntimeCap[1] == 1 { + nego.DBTimeZone, err = session.GetBytes(11) if err != nil { - return nil, err + return err } - if result.CompileTimeCaps[37]&2 == 2 { + if nego.CompileTimeCaps[37]&2 == 2 { _, _ = session.GetInt(4, false, false) } } - //getNum := func(session *network.Session, flag uint8) (int, error) { - // if flag == 0 { - // return session.GetInt(1, false, false) - // } else { - // return session.GetInt(2, false, true) - // } - //} level := 0 for { var num int - if result.CompileTimeCaps[27] == 0 { + if nego.CompileTimeCaps[27] == 0 { num, err = session.GetInt(1, false, false) } else { num, err = session.GetInt(2, false, true) @@ -435,59 +800,104 @@ func buildTypeNego(nego *TCPNego, session *network.Session) (*DataTypeNego, erro } level++ } - return &result, nil -} -func (nego *DataTypeNego) bytes() []byte { - var result bytes.Buffer - //var result = make([]byte, 7, 1000) + return nil +} +func (nego *DataTypeNego) write(session *network.Session) error { + session.ResetBuffer() if nego.Server.ServerCompileTimeCaps == nil || len(nego.Server.ServerCompileTimeCaps) <= 27 || nego.Server.ServerCompileTimeCaps[27] == 0 { nego.CompileTimeCaps[27] = 0 } - //result.WriteByte(nego.MessageCode) - result.Write([]byte{nego.MessageCode, 0, 0, 0, 0, nego.Server.ServerFlags, uint8(len(nego.CompileTimeCaps))}) - result.Write(nego.CompileTimeCaps) - result.WriteByte(uint8(len(nego.RuntimeCap))) - result.Write(nego.RuntimeCap) - //result[0] = nego.MessageCode - //result[5] = nego.Server.ServerFlags - //result[6] = uint8(len(nego.CompileTimeCaps)) - //result = append(result, nego.CompileTimeCaps...) - //result = append(result, uint8(len(nego.RuntimeCap))) - //result = append(result, nego.RuntimeCap...) + session.PutBytes(nego.MessageCode) + // client remote in + //session.PutBytes(0, 0, 0, 0) + session.PutInt(nego.Server.ServerCharset, 2, false, false) + // client remote out + session.PutInt(nego.Server.ServerCharset, 2, false, false) + session.PutBytes(nego.Server.ServerFlags, uint8(len(nego.CompileTimeCaps))) + session.PutBytes(nego.CompileTimeCaps...) + session.PutBytes(uint8(len(nego.RuntimeCap))) + session.PutBytes(nego.RuntimeCap...) if nego.RuntimeCap[1]&1 == 1 { - result.Write(TZBytes()) - //result = append(result, TZBytes()...) + session.PutBytes(TZBytes()...) if nego.CompileTimeCaps[37]&2 == 2 { - result.Write([]byte{0, 0, 0, 0}) - //result = append(result, []byte{0, 0, 0, 0}...) + session.PutBytes(0, 0, 0, 21) } } - temp := []byte{0, 0} - binary.LittleEndian.PutUint16(temp, uint16(nego.Server.ServernCharset)) - result.Write(temp) - //result = append(result, temp...) + session.PutInt(nego.Server.ServernCharset, 2, false, false) // marshal type reps size := nego.RuntimeTypeAndRep[0] if nego.CompileTimeCaps[27] == 0 { for _, x := range nego.RuntimeTypeAndRep[1:size] { - result.WriteByte(uint8(x)) + session.PutBytes(uint8(x)) + //result.WriteByte(uint8(x)) //result = append(result, uint8(x)) } - result.WriteByte(0) + session.PutBytes(0) + //result.WriteByte(0) //result = append(result, 0) } else { for _, x := range nego.RuntimeTypeAndRep[1:size] { - binary.BigEndian.PutUint16(temp, uint16(x)) + session.PutInt(x, 2, true, false) + //binary.BigEndian.PutUint16(temp, uint16(x)) //result = append(result, temp...) - result.Write(temp) + //result.Write(temp) } - result.Write([]byte{0, 0}) + session.PutBytes(0, 0) + //result.Write([]byte{0, 0}) //result = append(result, []byte{0, 0}...) } - return result.Bytes() + return session.Write() } +//func (nego *DataTypeNego) bytes() []byte { +// //var result bytes.Buffer +// //var result = make([]byte, 7, 1000) +// if nego.Server.ServerCompileTimeCaps == nil || len(nego.Server.ServerCompileTimeCaps) <= 27 || nego.Server.ServerCompileTimeCaps[27] == 0 { +// nego.CompileTimeCaps[27] = 0 +// } +// //result.WriteByte(nego.MessageCode) +// //result.Write([]byte{nego.MessageCode, 0, 0, 0, 0, nego.Server.ServerFlags, uint8(len(nego.CompileTimeCaps))}) +// //result.Write(nego.CompileTimeCaps) +// //result.WriteByte(uint8(len(nego.RuntimeCap))) +// //result.Write(nego.RuntimeCap) +// //result[0] = nego.MessageCode +// //result[5] = nego.Server.ServerFlags +// //result[6] = uint8(len(nego.CompileTimeCaps)) +// //result = append(result, nego.CompileTimeCaps...) +// //result = append(result, uint8(len(nego.RuntimeCap))) +// //result = append(result, nego.RuntimeCap...) +// if nego.RuntimeCap[1]&1 == 1 { +// result.Write(TZBytes()) +// if nego.CompileTimeCaps[37]&2 == 2 { +// result.Write([]byte{0, 0, 0, 21}) +// } +// } +// temp := []byte{0, 0} +// binary.LittleEndian.PutUint16(temp, uint16(nego.Server.ServernCharset)) +// result.Write(temp) +// //result = append(result, temp...) +// // marshal type reps +// size := nego.RuntimeTypeAndRep[0] +// if nego.CompileTimeCaps[27] == 0 { +// for _, x := range nego.RuntimeTypeAndRep[1:size] { +// result.WriteByte(uint8(x)) +// //result = append(result, uint8(x)) +// } +// result.WriteByte(0) +// //result = append(result, 0) +// } else { +// for _, x := range nego.RuntimeTypeAndRep[1:size] { +// binary.BigEndian.PutUint16(temp, uint16(x)) +// //result = append(result, temp...) +// result.Write(temp) +// } +// result.Write([]byte{0, 0}) +// //result = append(result, []byte{0, 0}...) +// } +// return result.Bytes() +//} + func TZBytes() []byte { _, offset := time.Now().Zone() hours := int8(offset / 3600) diff --git a/db_version.go b/db_version.go index d1eff93c..abce4be6 100644 --- a/db_version.go +++ b/db_version.go @@ -28,7 +28,7 @@ func GetDBVersion(session *network.Session) (*DBVersion, error) { if err != nil { return nil, err } - msg, err := session.GetInt(1, false, false) + msg, err := session.GetByte() if msg != 8 { return nil, errors.New(fmt.Sprintf("message code error: received code %d and expected code is 8", msg)) } @@ -36,7 +36,7 @@ func GetDBVersion(session *network.Session) (*DBVersion, error) { if err != nil { return nil, err } - info, err := session.GetBytes(int(length)) + info, err := session.GetString(int(length)) if err != nil { return nil, err } @@ -49,7 +49,7 @@ func GetDBVersion(session *network.Session) (*DBVersion, error) { number>>12&0xF, number>>8&0xF, number&0xFF) ret := &DBVersion{ - Info: string(info), + Info: info, Text: text, Number: uint16(version), MajorVersion: int(number >> 24 & 0xFF), diff --git a/lob.go b/lob.go index 10741800..a867cc6a 100644 --- a/lob.go +++ b/lob.go @@ -3,6 +3,7 @@ package go_ora import ( "bytes" "errors" + "fmt" "github.com/sijms/go-ora/network" ) @@ -29,25 +30,27 @@ func (lob *Lob) littleEndianClob() bool { } return false } -func (lob *Lob) getSize(session *network.Session) (size int64, err error) { +func (lob *Lob) getSize(connection *Connection) (size int64, err error) { + session := connection.session err = lob.write(session, 1) if err != nil { return } - err = lob.read(session) + err = lob.read(connection) if err != nil { return } size = lob.size return } -func (lob *Lob) getData(session *network.Session) (data []byte, err error) { +func (lob *Lob) getData(connection *Connection) (data []byte, err error) { + session := connection.session lob.sourceOffset = 1 err = lob.write(session, 2) if err != nil { return } - err = lob.read(session) + err = lob.read(connection) if err != nil { return } @@ -137,8 +140,9 @@ func (lob *Lob) write(session *network.Session, operationID int) error { return session.Write() } -func (lob *Lob) read(session *network.Session) error { +func (lob *Lob) read(connection *Connection) error { loop := true + session := connection.session for loop { msg, err := session.GetByte() if err != nil { @@ -204,7 +208,25 @@ func (lob *Lob) read(session *network.Session) error { if err != nil { return err } + case 15: + warning, err := network.NewWarningObject(session) + if err != nil { + return err + } + if warning != nil { + fmt.Println(warning) + } + case 23: + opCode, err := session.GetByte() + if err != nil { + return err + } + err = connection.getServerNetworkInformation(opCode) + if err != nil { + return err + } default: + fmt.Println(msg) return errors.New("TTC error") } } @@ -212,24 +234,25 @@ func (lob *Lob) read(session *network.Session) error { } func (lob *Lob) readData(session *network.Session) error { num1 := 0 // data readed in the call of this function - var chunkSize byte = 0 + var chunkSize int = 0 var err error //num3 := offset // the data readed from the start of read operation num4 := 0 for num4 != 4 { switch num4 { case 0: - chunkSize, err = session.GetByte() + nb, err := session.GetByte() if err != nil { return err } + chunkSize = int(nb) if chunkSize == 0xFE { num4 = 2 } else { num4 = 1 } case 1: - chunk, err := session.GetBytes(int(chunkSize)) + chunk, err := session.GetBytes(chunkSize) if err != nil { return err } @@ -237,7 +260,13 @@ func (lob *Lob) readData(session *network.Session) error { num1 += int(chunkSize) num4 = 4 case 2: - chunkSize, err = session.GetByte() + if session.UseBigClrChunks { + chunkSize, err = session.GetInt(4, true, true) + } else { + var nb byte + nb, err = session.GetByte() + chunkSize = int(nb) + } if err != nil { return err } diff --git a/network/accept_packet.go b/network/accept_packet.go index a9b015d8..96f3d461 100644 --- a/network/accept_packet.go +++ b/network/accept_packet.go @@ -1,6 +1,9 @@ package network -import "encoding/binary" +import ( + "encoding/binary" + "fmt" +) //type AcceptPacket Packet type AcceptPacket struct { @@ -10,6 +13,10 @@ type AcceptPacket struct { } func (pck *AcceptPacket) bytes() []byte { + // ptkSize := 41 + // if pck.sessionCtx.Version < 315 { + // ptkSize = 32 + // } output := pck.packet.bytes() //output := make([]byte, pck.dataOffset) //binary.BigEndian.PutUint16(output[0:], pck.packet.length) @@ -17,13 +24,20 @@ func (pck *AcceptPacket) bytes() []byte { //output[5] = pck.packet.flag binary.BigEndian.PutUint16(output[8:], pck.sessionCtx.Version) binary.BigEndian.PutUint16(output[10:], pck.sessionCtx.Options) - binary.BigEndian.PutUint16(output[12:], pck.sessionCtx.SessionDataUnit) - binary.BigEndian.PutUint16(output[14:], pck.sessionCtx.TransportDataUnit) + if pck.sessionCtx.Version < 315 { + binary.BigEndian.PutUint16(output[12:], uint16(pck.sessionCtx.SessionDataUnit)) + binary.BigEndian.PutUint16(output[14:], uint16(pck.sessionCtx.TransportDataUnit)) + } else { + binary.BigEndian.PutUint32(output[32:], pck.sessionCtx.SessionDataUnit) + binary.BigEndian.PutUint32(output[36:], pck.sessionCtx.TransportDataUnit) + } + binary.BigEndian.PutUint16(output[16:], pck.sessionCtx.Histone) binary.BigEndian.PutUint16(output[18:], uint16(len(pck.buffer))) binary.BigEndian.PutUint16(output[20:], pck.packet.dataOffset) output[22] = pck.sessionCtx.ACFL0 output[23] = pck.sessionCtx.ACFL1 + // s output = append(output, pck.buffer...) return output } @@ -64,12 +78,12 @@ func newAcceptPacketFromData(packetData []byte) *AcceptPacket { pck := AcceptPacket{ packet: Packet{ dataOffset: binary.BigEndian.Uint16(packetData[20:]), - length: binary.BigEndian.Uint16(packetData), + length: uint32(binary.BigEndian.Uint16(packetData)), packetType: PacketType(packetData[4]), flag: packetData[5], }, sessionCtx: SessionContext{ - connOption: ConnectionOption{}, + ConnOption: ConnectionOption{}, SID: nil, Version: binary.BigEndian.Uint16(packetData[8:]), LoVersion: 0, @@ -80,24 +94,44 @@ func newAcceptPacketFromData(packetData []byte) *AcceptPacket { ReconAddr: reconAdd, ACFL0: packetData[22], ACFL1: packetData[23], - SessionDataUnit: binary.BigEndian.Uint16(packetData[12:]), - TransportDataUnit: binary.BigEndian.Uint16(packetData[14:]), + SessionDataUnit: uint32(binary.BigEndian.Uint16(packetData[12:])), + TransportDataUnit: uint32(binary.BigEndian.Uint16(packetData[14:])), UsingAsyncReceivers: false, IsNTConnected: false, OnBreakReset: false, GotReset: false, }, - buffer: packetData[32:], } + pck.buffer = packetData[int(pck.packet.dataOffset):] + if pck.sessionCtx.Version > 315 { + pck.sessionCtx.SessionDataUnit = binary.BigEndian.Uint32(packetData[32:]) + pck.sessionCtx.TransportDataUnit = binary.BigEndian.Uint32(packetData[36:]) + } + if (pck.packet.flag & 1) > 0 { + fmt.Println("contain SID data") + pck.packet.length -= 16 + pck.sessionCtx.SID = packetData[int(pck.packet.length):] + } + if pck.sessionCtx.TransportDataUnit < pck.sessionCtx.SessionDataUnit { + pck.sessionCtx.SessionDataUnit = pck.sessionCtx.TransportDataUnit + } + // if pck.sessionCtx.Version >= 310 { + // byteCount := binary.BigEndian.Uint16(packetData[28:]) + // if byteCount > 0 { + // byteOffset := binary.BigEndian.Uint16(packetData[30:]) + // } + // } + // dataLen := binary.BigEndian.Uint16(packetData[18:]) //if pck.length != uint16(len(packetData)) { // return nil //} //if pck.packetType != ACCEPT { // return nil //} - if pck.packet.dataOffset != 32 { - return nil - } + // if pck.packet.dataOffset != 32 { + // fmt.Println("data offset: ", pck.packet.dataOffset) + // return nil + // } if binary.BigEndian.Uint16(packetData[18:]) != uint16(len(pck.buffer)) { return nil } diff --git a/network/connect_option.go b/network/connect_option.go index 58abb35f..fe5d7297 100644 --- a/network/connect_option.go +++ b/network/connect_option.go @@ -19,8 +19,8 @@ type ConnectionOption struct { TransportConnectTo int SSLVersion string WalletDict string - TransportDataUnitSize uint16 - SessionDataUnitSize uint16 + TransportDataUnitSize uint32 + SessionDataUnitSize uint32 Protocol string Host string UserID string @@ -34,8 +34,9 @@ type ConnectionOption struct { DBName string ClientData ClientData //InAddrAny bool - Tracer trace.Tracer - connData string + Tracer trace.Tracer + connData string + SNOConfig map[string]string } func (op *ConnectionOption) ConnectionData() string { diff --git a/network/connect_packet.go b/network/connect_packet.go index 6138aab4..aeb13b9c 100644 --- a/network/connect_packet.go +++ b/network/connect_packet.go @@ -17,11 +17,22 @@ func (pck *ConnectPacket) bytes() []byte { binary.BigEndian.PutUint16(output[8:], pck.sessionCtx.Version) binary.BigEndian.PutUint16(output[10:], pck.sessionCtx.LoVersion) binary.BigEndian.PutUint16(output[12:], pck.sessionCtx.Options) - binary.BigEndian.PutUint16(output[14:], pck.sessionCtx.SessionDataUnit) - binary.BigEndian.PutUint16(output[16:], pck.sessionCtx.TransportDataUnit) + num := uint16(pck.sessionCtx.SessionDataUnit) + if pck.sessionCtx.SessionDataUnit > 0xFFFF { + num = 0xFFFF + } + binary.BigEndian.PutUint16(output[14:], num) + binary.BigEndian.PutUint32(output[58:], pck.sessionCtx.SessionDataUnit) + num = uint16(pck.sessionCtx.TransportDataUnit) + if pck.sessionCtx.TransportDataUnit > 0xFFFF { + num = 0xFFFF + } + binary.BigEndian.PutUint16(output[16:], num) + binary.BigEndian.PutUint32(output[62:], pck.sessionCtx.TransportDataUnit) + binary.BigEndian.PutUint32(output[66:], 0) output[18] = 79 output[19] = 152 - binary.BigEndian.PutUint16(output[22:], pck.sessionCtx.Histone) + binary.BigEndian.PutUint16(output[22:], pck.sessionCtx.OurOne) binary.BigEndian.PutUint16(output[24:], uint16(len(pck.buffer))) binary.BigEndian.PutUint16(output[26:], pck.packet.dataOffset) output[32] = pck.sessionCtx.ACFL0 @@ -36,21 +47,23 @@ func (pck *ConnectPacket) getPacketType() PacketType { return pck.packet.packetType } func newConnectPacket(sessionCtx SessionContext) *ConnectPacket { - connectData := sessionCtx.connOption.ConnectionData() - length := uint16(len(connectData)) + connectData := sessionCtx.ConnOption.ConnectionData() + length := uint32(len(connectData)) if length > 230 { length = 0 } - length += 58 + length += 70 sessionCtx.Histone = 1 - sessionCtx.ACFL0 = 4 - sessionCtx.ACFL1 = 4 + sessionCtx.ACFL0 = 1 + sessionCtx.ACFL1 = 1 + //sessionCtx.ACFL0 = 4 + //sessionCtx.ACFL1 = 4 return &ConnectPacket{ sessionCtx: sessionCtx, packet: Packet{ - dataOffset: 58, + dataOffset: 70, length: length, packetType: CONNECT, flag: 0, diff --git a/network/data_packet.go b/network/data_packet.go index 32b7968b..229c9723 100644 --- a/network/data_packet.go +++ b/network/data_packet.go @@ -1,54 +1,70 @@ package network import ( + "bytes" "encoding/binary" ) type DataPacket struct { - packet Packet - dataFlag uint16 - buffer []byte + Packet + sessionCtx *SessionContext + dataFlag uint16 + buffer []byte } func (pck *DataPacket) bytes() []byte { - output := pck.packet.bytes() - binary.BigEndian.PutUint16(output[8:], pck.dataFlag) + output := bytes.Buffer{} + temp := make([]byte, 0xA) + if pck.sessionCtx.handshakeComplete && pck.sessionCtx.Version >= 315 { + binary.BigEndian.PutUint32(temp, pck.length) + } else { + binary.BigEndian.PutUint16(temp, uint16(pck.length)) + } + temp[4] = uint8(pck.packetType) + temp[5] = pck.flag + binary.BigEndian.PutUint16(temp[8:], pck.dataFlag) + output.Write(temp) if len(pck.buffer) > 0 { - output = append(output, pck.buffer...) + output.Write(pck.buffer) } - return output + return output.Bytes() } -func (pck *DataPacket) getPacketType() PacketType { - return pck.packet.packetType -} -func newDataPacket(initialData []byte) *DataPacket { +func newDataPacket(initialData []byte, sessionCtx *SessionContext) *DataPacket { return &DataPacket{ - packet: Packet{ + Packet: Packet{ dataOffset: 0xA, - length: uint16(len(initialData) + 0xA), + length: uint32(len(initialData)) + 0xA, packetType: DATA, flag: 0, }, - dataFlag: 0, - buffer: initialData, + sessionCtx: sessionCtx, + dataFlag: 0, + buffer: initialData, } } -func newDataPacketFromData(packetData []byte) *DataPacket { +func newDataPacketFromData(packetData []byte, sessionCtx *SessionContext) *DataPacket { if len(packetData) <= 0xA || PacketType(packetData[4]) != DATA { return nil } - return &DataPacket{ - packet: Packet{ + pck := &DataPacket{ + Packet: Packet{ dataOffset: 0xA, - length: binary.BigEndian.Uint16(packetData), + //length: binary.BigEndian.Uint16(packetData), packetType: PacketType(packetData[4]), flag: packetData[5], }, - dataFlag: binary.BigEndian.Uint16(packetData[8:]), - buffer: packetData[10:], + sessionCtx: sessionCtx, + dataFlag: binary.BigEndian.Uint16(packetData[8:]), + buffer: packetData[10:], + } + if sessionCtx.handshakeComplete && sessionCtx.Version >= 315 { + pck.length = binary.BigEndian.Uint32(packetData) + } else { + pck.length = uint32(binary.BigEndian.Uint16(packetData)) } + return pck } //func (pck *DataPacket) Data() []byte { diff --git a/network/marker_packet.go b/network/marker_packet.go index bbacae06..3703bb7a 100644 --- a/network/marker_packet.go +++ b/network/marker_packet.go @@ -3,7 +3,8 @@ package network import "encoding/binary" type MarkerPacket struct { - packet Packet + packet Packet + sessionCtx *SessionContext //length uint16 //packetType PacketType //flag uint8 @@ -12,14 +13,18 @@ type MarkerPacket struct { } func (pck *MarkerPacket) bytes() []byte { - return []byte{0, 0xB, 0, 0, 0xC, 0, 0, 0, pck.markerType, 0, pck.markerData} + if pck.sessionCtx.handshakeComplete && pck.sessionCtx.Version >= 315 { + return []byte{0, 0x0, 0, 0xB, 0xC, 0, 0, 0, pck.markerType, 0, pck.markerData} + } else { + return []byte{0, 0xB, 0, 0, 0xC, 0, 0, 0, pck.markerType, 0, pck.markerData} + } } func (pck *MarkerPacket) getPacketType() PacketType { return pck.packet.packetType } -func newMarkerPacket(markerData uint8) *MarkerPacket { +func newMarkerPacket(markerData uint8, sessionCtx *SessionContext) *MarkerPacket { return &MarkerPacket{ packet: Packet{ dataOffset: 0, @@ -27,24 +32,30 @@ func newMarkerPacket(markerData uint8) *MarkerPacket { packetType: MARKER, flag: 0, }, + sessionCtx: sessionCtx, markerType: 1, markerData: markerData, } } -func newMarkerPacketFromData(packetData []byte) *MarkerPacket { +func newMarkerPacketFromData(packetData []byte, sessionCtx *SessionContext) *MarkerPacket { if len(packetData) != 0xB { return nil } pck := MarkerPacket{ packet: Packet{ dataOffset: 0, - length: binary.BigEndian.Uint16(packetData), packetType: PacketType(packetData[4]), flag: packetData[5], }, + sessionCtx: sessionCtx, markerType: packetData[8], markerData: packetData[10], } + if sessionCtx.handshakeComplete && sessionCtx.Version >= 315 { + pck.packet.length = binary.BigEndian.Uint32(packetData) + } else { + pck.packet.length = uint32(binary.BigEndian.Uint16(packetData)) + } if pck.packet.packetType != MARKER { return nil } diff --git a/network/packets.go b/network/packets.go index c437912e..eeb1071c 100644 --- a/network/packets.go +++ b/network/packets.go @@ -28,7 +28,7 @@ const ( type Packet struct { //sessionCtx SessionContext dataOffset uint16 - length uint16 + length uint32 packetType PacketType flag uint8 //NSPFSID int @@ -47,7 +47,7 @@ type Packet struct { func newPacket(packetData []byte) *Packet { return &Packet{ - length: binary.BigEndian.Uint16(packetData), + length: uint32(binary.BigEndian.Uint16(packetData)), packetType: PacketType(packetData[4]), flag: packetData[5], } @@ -58,7 +58,7 @@ func (pck *Packet) bytes() []byte { if pck.dataOffset > 8 { output = append(output, make([]byte, pck.dataOffset-8)...) } - binary.BigEndian.PutUint16(output, pck.length) + binary.BigEndian.PutUint16(output, uint16(pck.length)) output[4] = uint8(pck.packetType) output[5] = pck.flag return output diff --git a/network/redirect_packet.go b/network/redirect_packet.go index 69f262a2..8ed3e069 100644 --- a/network/redirect_packet.go +++ b/network/redirect_packet.go @@ -31,7 +31,7 @@ func newRedirectPacketFromData(packetData []byte) *RedirectPacket { pck := RedirectPacket{ packet: Packet{ dataOffset: 10, - length: binary.BigEndian.Uint16(packetData), + length: uint32(binary.BigEndian.Uint16(packetData)), packetType: PacketType(packetData[4]), flag: packetData[5], }, diff --git a/network/refuse_packet.go b/network/refuse_packet.go index 7dbffb62..6d6c2382 100644 --- a/network/refuse_packet.go +++ b/network/refuse_packet.go @@ -38,7 +38,7 @@ func newRefusePacketFromData(packetData []byte) *RefusePacket { return &RefusePacket{ packet: Packet{ dataOffset: 12, - length: binary.BigEndian.Uint16(packetData), + length: uint32(binary.BigEndian.Uint16(packetData)), packetType: PacketType(packetData[4]), flag: 0, }, diff --git a/network/security/diffie_hellman.go b/network/security/diffie_hellman.go new file mode 100644 index 00000000..e292e202 --- /dev/null +++ b/network/security/diffie_hellman.go @@ -0,0 +1,29 @@ +package security + +type DiffieHellman struct { + buffer_1 []byte + buffer_2 []byte + size_1 int + size_2 int +} + +func NewDiffieHellman(buffer1, buffer2 []byte, size1, size2 int) *DiffieHellman { + return &DiffieHellman{ + buffer_1: buffer1, + buffer_2: buffer2, + size_1: size1, + size_2: size2, + } +} + +//func (dh *DiffieHellman) GetPublicKey() { +// array_1 := make([]int, 257) +// array_2 := make([]int, 257) +// array_3 := make([]uint8, 512) +// num := (dh.size_2 + 7) >> 3 +// m := (dh.size_1 + 7) >> 3 +// n := (dh.size_1 / 16) + 1 +// l := make([]byte, m) +// +// // calling function b(array_3, num) +//} diff --git a/network/session.go b/network/session.go index 6ba89a73..ba18b3ad 100644 --- a/network/session.go +++ b/network/session.go @@ -42,16 +42,20 @@ type Session struct { Summary *SummaryObject states []sessionState StrConv *converters.StringConverter + UseBigClrChunks bool + ClrChunkSize int } func NewSession(connOption ConnectionOption) *Session { return &Session{ - conn: nil, - inBuffer: nil, - index: 0, - connOption: connOption, - Context: NewSessionContext(connOption), - Summary: nil, + conn: nil, + inBuffer: nil, + index: 0, + connOption: connOption, + Context: NewSessionContext(connOption), + Summary: nil, + UseBigClrChunks: false, + ClrChunkSize: 0x40, } } @@ -103,7 +107,7 @@ func (session *Session) Connect() error { if err != nil { return err } - if connectPacket.packet.length == 58 { + if uint16(connectPacket.packet.length) == connectPacket.packet.dataOffset { session.PutBytes(connectPacket.buffer...) err = session.Write() if err != nil { @@ -117,6 +121,16 @@ func (session *Session) Connect() error { if acceptPacket, ok := pck.(*AcceptPacket); ok { *session.Context = acceptPacket.sessionCtx + session.Context.handshakeComplete = true + + //if (this.m_sessionCtx.m_ACFL0 & 1) != 0 && + // (this.m_sessionCtx.m_ACFL0 & 4) == 0 && + // (this.m_sessionCtx.m_ACFL1 & 8) == 0 { + // this.m_sessionCtx.m_ano.StartNegotiation(); + // } else { + // this.m_sessionCtx.m_bAnoEnabled = false; + // this.m_sessionCtx.m_ano = (Ano) null; + // } return nil } if redirectPacket, ok := pck.(*RedirectPacket); ok { @@ -199,14 +213,14 @@ func (session *Session) Write() error { size := session.outBuffer.Len() if size == 0 { // send empty data packet - return session.writePacket(newDataPacket(nil)) + return session.writePacket(newDataPacket(nil, session.Context)) //return errors.New("the output buffer is empty") } segment := int(session.Context.SessionDataUnit - 20) offset := 0 for size > segment { - err := session.writePacket(newDataPacket(outputBytes[offset : offset+segment])) + err := session.writePacket(newDataPacket(outputBytes[offset:offset+segment], session.Context)) if err != nil { session.outBuffer.Reset() return err @@ -215,7 +229,7 @@ func (session *Session) Write() error { offset += segment } if size != 0 { - err := session.writePacket(newDataPacket(outputBytes[offset:])) + err := session.writePacket(newDataPacket(outputBytes[offset:], session.Context)) if err != nil { session.outBuffer.Reset() return err @@ -285,22 +299,28 @@ func (session *Session) readPacket() (PacketInterface, error) { if err != nil { return nil, err } - length := binary.BigEndian.Uint16(head) + pckType := PacketType(head[4]) + var length uint32 + if session.Context.handshakeComplete && session.Context.Version >= 315 { + length = binary.BigEndian.Uint32(head) + } else { + length = uint32(binary.BigEndian.Uint16(head)) + } length -= 8 body := make([]byte, length) - index := uint16(0) + index := uint32(0) for index < length { temp, err := conn.Read(body[index:]) if err != nil { if e, ok := err.(net.Error); ok && e.Timeout() && temp != 0 { - index += uint16(temp) + index += uint32(temp) continue } return nil, err } - index += uint16(temp) + index += uint32(temp) } - pckType := PacketType(head[4]) + if pckType == RESEND { for _, pck := range session.sendPcks { //log.Printf("Request: %#v\n\n", pck.bytes()) @@ -333,9 +353,9 @@ func (session *Session) readPacket() (PacketInterface, error) { pck := newRedirectPacketFromData(packetData) dataLen := binary.BigEndian.Uint16(packetData[8:]) var data string - if pck.packet.length <= pck.packet.dataOffset { + if uint16(pck.packet.length) <= pck.packet.dataOffset { packetData, err = readPacketData(session.conn) - dataPck := newDataPacketFromData(packetData) + dataPck := newDataPacketFromData(packetData, session.Context) data = string(dataPck.buffer) } else { data = string(packetData[10 : 10+dataLen]) @@ -358,9 +378,9 @@ func (session *Session) readPacket() (PacketInterface, error) { // connect through redirectConnectData return pck, nil case DATA: - return newDataPacketFromData(packetData), nil + return newDataPacketFromData(packetData, session.Context), nil case MARKER: - pck := newMarkerPacketFromData(packetData) + pck := newMarkerPacketFromData(packetData, session.Context) breakConnection := false resetConnection := false switch pck.markerType { @@ -384,7 +404,7 @@ func (session *Session) readPacket() (PacketInterface, error) { if err != nil { return nil, err } - pck = newMarkerPacketFromData(packetData) + pck = newMarkerPacketFromData(packetData, session.Context) if pck == nil { return nil, errors.New("connection break") } @@ -403,7 +423,7 @@ func (session *Session) readPacket() (PacketInterface, error) { trials++ } session.ResetBuffer() - err = session.writePacket(newMarkerPacket(2)) + err = session.writePacket(newMarkerPacket(2, session.Context)) if err != nil { return nil, err } @@ -411,7 +431,7 @@ func (session *Session) readPacket() (PacketInterface, error) { if err != nil { return nil, err } - dataPck := newDataPacketFromData(packetData) + dataPck := newDataPacketFromData(packetData, session.Context) if dataPck == nil { return nil, errors.New("connection break") } @@ -436,6 +456,15 @@ func (session *Session) readPacket() (PacketInterface, error) { } } +func (session *Session) PutString(data string) { + session.PutClr([]byte(data)) +} + +func (session *Session) GetString(length int) (string, error) { + ret, err := session.GetClr() + return string(ret[:length]), err +} + func (session *Session) PutBytes(data ...byte) { session.outBuffer.Write(data) //session.outBuffer = append(session.outBuffer, ) @@ -599,31 +628,29 @@ func (session *Session) PutInt(number interface{}, size uint8, bigEndian bool, c func (session *Session) PutClr(data []byte) { dataLen := len(data) - if dataLen == 0 { - session.outBuffer.WriteByte(0) - //session.outBuffer = append(session.outBuffer, 0) - return - } - if dataLen > 0x40 { + if dataLen > 0xFC { session.outBuffer.WriteByte(0xFE) - //session.outBuffer = append(session.outBuffer, 0xFE) - } - start := 0 - for start < dataLen { - end := start + 0x40 - if end > dataLen { - end = dataLen + start := 0 + for start < dataLen { + end := start + session.ClrChunkSize + if end > dataLen { + end = dataLen + } + temp := data[start:end] + if session.UseBigClrChunks { + session.PutInt(len(temp), 4, true, true) + } else { + session.outBuffer.WriteByte(uint8(len(temp))) + } + session.outBuffer.Write(temp) + start += session.ClrChunkSize } - temp := data[start:end] - session.outBuffer.WriteByte(uint8(len(temp))) - session.outBuffer.Write(temp) - //session.outBuffer = append(session.outBuffer, uint8(len(temp))) - //session.outBuffer = append(session.outBuffer, temp...) - start += 64 - } - if dataLen > 0x40 { session.outBuffer.WriteByte(0) - //session.outBuffer = append(session.outBuffer, 0) + } else if dataLen == 0 { + session.outBuffer.WriteByte(0) + } else { + session.outBuffer.WriteByte(uint8(len(data))) + session.outBuffer.Write(data) } } @@ -729,10 +756,10 @@ func (session *Session) GetClr() (output []byte, err error) { if err != nil { return } - if size == 253 { - err = errors.New("TTC error") - return - } + //if size == 253 { + // err = errors.New("TTC error") + // return + //} if size == 0 || size == 0xFF { output = nil err = nil @@ -745,12 +772,16 @@ func (session *Session) GetClr() (output []byte, err error) { //output = make([]byte, 0, 1000) var tempBuffer bytes.Buffer for { - var size1 uint8 - size1, err = session.GetByte() + var size1 int + if session.UseBigClrChunks { + size1, err = session.GetInt(4, true, true) + } else { + size1, err = session.GetInt(1, true, true) + } if err != nil || size1 == 0 { break } - rb, err = session.read(int(size1)) + rb, err = session.read(size1) if err != nil { return } @@ -792,62 +823,3 @@ func (session *Session) GetKeyVal() (key []byte, val []byte, num int, err error) num, err = session.GetInt(4, true, true) return } - -//func (session *Session) DoAuth(logonMode int) error{ -// index := strings.LastIndex(session.connOption.ClientData.ProgramName, "/") -// if index < 0 { -// index = 0 -// } else { -// index += 1 -// } -// ikeys := []string{"AUTH_TERMINAL", "AUTH_PROGRAM_NM", "AUTH_MACHINE", "AUTH_PID", "AUTH_SID"} -// ivals := []string{ -// session.connOption.ClientData.HostName, -// session.connOption.ClientData.ProgramName[index:], -// session.connOption.ClientData.HostName, -// fmt.Sprintf("%d", session.connOption.ClientData.PID), -// session.connOption.ClientData.UserName, -// } -// inums := []int{0, 0, 0, 0, 0} -// -// var pck = newDataPacket([]byte {3, 118, 0, 1}) // message_code, function_code, sequence_number, 1 -// pck.AppendInt(len(session.connOption.UserID), 4, false, true) -// pck.AppendInt(logonMode | 1, 4, false, true) -// pck.AppendBytes([]byte{1, 1, 5, 1, 1}, false) -// pck.AppendBytes([]byte(session.connOption.UserID), false) -// pck.AppendKeyVal(ikeys, ivals, inums) -// authData, err := session.SendData(pck.Data()) -// if err != nil { -// return err -// } -// rPck := newDataPacket(authData) -// messageCode, err := rPck.ReadInt(1, false, false) -// if err != nil { -// return err -// } -// if messageCode != 8 { -// return errors.New(fmt.Sprintf("message code error: received code %d and expected code is 8", messageCode)) -// } -// dictLen, err := rPck.ReadInt(4, true, true) -// if err != nil { -// return err -// } -// keys, vals, nums, err := rPck.ReadKeyVal(int(dictLen)) -// if err != nil { -// fmt.Println(err) -// return err -// } -// for x:=0; x < len(keys); x++ { -// if bytes.Compare(keys[x], []byte("AUTH_SESSKEY")) == 0 { -// session.key = vals[x] -// } else if bytes.Compare(keys[x], []byte("AUTH_VFR_DATA")) == 0 { -// session.salt = vals[x] -// session.verifierType = nums[x] -// } -// } -// if len(session.key) != 64 && len(session.key) != 96 { -// return errors.New("TCC Error: SessionKey should be either 64 or 96 bytes long.") -// } -// // load the error object -// return nil -//} diff --git a/network/session_ctx.go b/network/session_ctx.go index 12d1a959..b9309c22 100644 --- a/network/session_ctx.go +++ b/network/session_ctx.go @@ -11,7 +11,7 @@ package network type SessionContext struct { //conn net.Conn - connOption ConnectionOption + ConnOption ConnectionOption //PortNo int //InstanceName string //HostName string @@ -25,34 +25,34 @@ type SessionContext struct { //internal WriterStream m_writerStream; //internal ITransportAdapter m_transportAdapter; //ConnectData string - Version uint16 - LoVersion uint16 - Options uint16 + Version uint16 + LoVersion uint16 + Options uint16 NegotiatedOptions uint16 - OurOne uint16 - Histone uint16 - ReconAddr string + OurOne uint16 + Histone uint16 + ReconAddr string + handshakeComplete bool //internal Ano m_ano; //internal bool m_bAnoEnabled; - ACFL0 uint8 - ACFL1 uint8 - SessionDataUnit uint16 - TransportDataUnit uint16 + ACFL0 uint8 + ACFL1 uint8 + SessionDataUnit uint32 + TransportDataUnit uint32 UsingAsyncReceivers bool - IsNTConnected bool - OnBreakReset bool - GotReset bool + IsNTConnected bool + OnBreakReset bool + GotReset bool } func NewSessionContext(connOption ConnectionOption) *SessionContext { return &SessionContext{ SessionDataUnit: connOption.SessionDataUnitSize, TransportDataUnit: connOption.TransportDataUnitSize, - Version: 312, + Version: 317, LoVersion: 300, - Options: 1 | 1024 | 2048, + Options: 1 | 1024 | 2048, /*1024 for urgent data transport*/ OurOne: 1, - connOption: connOption, + ConnOption: connOption, } } - diff --git a/network/summary_object.go b/network/summary_object.go index e52f9255..ae572b1d 100644 --- a/network/summary_object.go +++ b/network/summary_object.go @@ -151,7 +151,11 @@ func NewSummary(session *Session) (*SummaryObject, error) { flag := num == 0xFE for x := 0; x < length; x++ { if flag { - _, _ = session.GetByte() + if session.UseBigClrChunks { + _, _ = session.GetInt(4, true, true) + } else { + _, _ = session.GetByte() + } } result.bindErrors[x].errorCode, err = session.GetInt(2, true, true) if err != nil { @@ -174,7 +178,11 @@ func NewSummary(session *Session) (*SummaryObject, error) { flag := num == 0xFE for x := 0; x < length; x++ { if flag { - _, _ = session.GetByte() + if session.UseBigClrChunks { + _, _ = session.GetInt(4, true, true) + } else { + _, _ = session.GetByte() + } } result.bindErrors[x].rowOffset, err = session.GetInt(4, true, true) if err != nil { @@ -204,6 +212,16 @@ func NewSummary(session *Session) (*SummaryObject, error) { _, _ = session.GetByte() } } + if session.TTCVersion >= 7 { + result.RetCode, err = session.GetInt(4, true, true) + if err != nil { + return nil, err + } + result.CurRowNumber, err = session.GetInt(8, true, true) + if err != nil { + return nil, err + } + } if result.RetCode != 0 { result.ErrorMessage, err = session.GetClr() if err != nil { diff --git a/parameter.go b/parameter.go index f8cf29eb..9569a03b 100644 --- a/parameter.go +++ b/parameter.go @@ -105,6 +105,7 @@ type ParameterInfo struct { BValue []byte Value driver.Value getDataFromServer bool + oaccollid int } func (par *ParameterInfo) load(session *network.Session) error { @@ -191,7 +192,11 @@ func (par *ParameterInfo) load(session *network.Session) error { if err != nil { return err } - par.ContFlag, err = session.GetInt(4, true, true) + if session.TTCVersion >= 10 { + par.ContFlag, err = session.GetInt(8, true, true) + } else { + par.ContFlag, err = session.GetInt(4, true, true) + } if err != nil { return err } @@ -212,6 +217,9 @@ func (par *ParameterInfo) load(session *network.Session) error { if err != nil { return err } + if session.TTCVersion >= 8 { + par.oaccollid, err = session.GetInt(4, true, true) + } num1, err := session.GetInt(1, false, false) if err != nil { return err @@ -253,7 +261,11 @@ func (par *ParameterInfo) write(session *network.Session) error { //session.PutUint(par.Scale, 1, false, false) session.PutUint(par.MaxLen, 4, true, true) session.PutInt(par.MaxNoOfArrayElements, 4, true, true) - session.PutInt(par.ContFlag, 4, true, true) + if session.TTCVersion >= 10 { + session.PutInt(par.ContFlag, 8, true, true) + } else { + session.PutInt(par.ContFlag, 4, true, true) + } if par.ToID == nil { session.PutBytes(0) //session.PutInt(0, 1, false, false) @@ -266,6 +278,9 @@ func (par *ParameterInfo) write(session *network.Session) error { session.PutBytes(uint8(par.CharsetForm)) //session.PutUint(par.CharsetForm, 1, false, false) session.PutUint(par.MaxCharLen, 4, true, true) + if session.TTCVersion >= 8 { + session.PutInt(par.oaccollid, 4, true, true) + } return nil } diff --git a/ref_cursor.go b/ref_cursor.go index 8999d68f..f654f252 100644 --- a/ref_cursor.go +++ b/ref_cursor.go @@ -10,7 +10,7 @@ type RefCursor struct { MaxRowSize int parent *defaultStmt //ID int - //scnFromExe []int + //scnForSnapshot []int //connection *Connection //noOfRowsToFetch int //hasMoreRows bool @@ -24,7 +24,7 @@ func (cursor *RefCursor) load(session *network.Session) error { cursor._hasReturnClause = false cursor.disableCompression = true cursor.arrayBindCount = 1 - cursor.scnFromExe = make([]int, 2) + cursor.scnForSnapshot = make([]int, 2) cursor.stmtType = SELECT var err error cursor.len, err = session.GetByte() @@ -106,8 +106,8 @@ func (cursor *RefCursor) Query() (*DataSet, error) { cursor.connection.connOption.Tracer.Printf("Query RefCursor: %d", cursor.cursorID) cursor._noOfRowsToFetch = 25 cursor._hasMoreRows = true - if len(cursor.parent.scnFromExe) > 0 { - copy(cursor.scnFromExe, cursor.parent.scnFromExe) + if len(cursor.parent.scnForSnapshot) > 0 { + copy(cursor.scnForSnapshot, cursor.parent.scnForSnapshot) } session := cursor.connection.session session.ResetBuffer() diff --git a/tcp_protocol_nego.go b/tcp_protocol_nego.go index 6c90a254..5240df54 100644 --- a/tcp_protocol_nego.go +++ b/tcp_protocol_nego.go @@ -104,5 +104,15 @@ func NewTCPNego(session *network.Session) (*TCPNego, error) { if result.ServerCompileTimeCaps[16]&1 != 0 { session.HasFSAPCapability = true } + if result.ServerCompileTimeCaps == nil || len(result.ServerCompileTimeCaps) < 8 { + return nil, errors.New("server compile time caps length less than 8") + } + if len(result.ServerCompileTimeCaps) > 37 && result.ServerCompileTimeCaps[37]&32 != 0 { + session.UseBigClrChunks = true + session.ClrChunkSize = 0x7FFF + } + //this.m_b32kTypeSupported = this.m_dtyNeg.m_b32kTypeSupported; + //this.m_bSupportSessionStateOps = this.m_dtyNeg.m_bSupportSessionStateOps; + //this.m_marshallingEngine.m_bServerUsingBigSCN = this.m_serverCompiletimeCapabilities[7] >= (byte) 8; return &result, nil }