// codecvt facet for JIS multibyte code, UCS-2 wide-character code

// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#pragma once
#ifndef _CVT_XJIS_
#define _CVT_XJIS_
#include <yvals_core.h>
#if _STL_COMPILER_PREPROCESSOR
#include <cwchar>
#include <locale>

#pragma pack(push, _CRT_PACKING)
#pragma warning(push, _STL_WARNING_LEVEL)
#pragma warning(disable : _STL_DISABLED_WARNINGS)
_STL_DISABLE_CLANG_WARNINGS
#pragma push_macro("new")
#undef new

#ifndef _CVT_JIS_
#error do not include directly
#endif // _CVT_JIS_

namespace stdext {
    namespace cvt {

        // conversion-state type used by codecvt_jis; the facet keeps its shift state in the first byte of this object
        using _Statype = _CSTD mbstate_t;

// ESC (0x1b), the first byte of every shift sequence recognized or produced by codecvt_jis
#define _ESC_CODE 0x1b

        template <class _Elem, unsigned long _Maxcode = 0xffff>
        class codecvt_jis : public _STD codecvt<_Elem, char, _Statype> {
            // facet for converting between JIS-X0208 bytes and UCS-2 _Elem
            using _Table = _tab_jis<int>;

            bool _Jis_to_ucs(unsigned short& _Ch) const {
                // Convert _Ch in place from a JIS-X0208 code to UCS-2.
                // Returns true on success and false when the tables hold no mapping for _Ch
                // (for an unmapped one-byte code _Ch is overwritten with 0).
                if (_Ch < _Table::_Nlow) {
                    return true; // codes below _Nlow map to the identical wide value
                }

                if (_Ch < 0x100) { // one-byte code: direct lookup, table is indexed from _Nlow
                    _Ch = _Table::_Btw[_Ch - _Table::_Nlow];
                    return _Ch != 0; // a zero entry means the byte has no defined mapping
                }

                // two-byte code: binary search of _Dbvalid over the half-open index range [_Left, _Right);
                // _Dbtw supplies the wide value at the matching index
                // (the search presumes _Dbvalid is sorted in ascending order)
                unsigned long _Left  = 0;
                unsigned long _Right = static_cast<unsigned long>(_STD size(_Table::_Dbvalid));

                while (_Left < _Right) {
                    const unsigned long _Probe = (_Right + _Left) / 2;

                    if (_Ch == _Table::_Dbvalid[_Probe]) { // found the code, map it
                        _Ch = _Table::_Dbtw[_Probe];
                        return true;
                    }

                    if (_Ch < _Table::_Dbvalid[_Probe]) {
                        _Right = _Probe; // continue in the lower half
                    } else {
                        _Left = _Probe + 1; // continue in the upper half
                    }
                }

                return false; // not a valid two-byte code
            }

            bool _Ucs_to_jis(unsigned short& _Ch) const {
                // Convert _Ch in place from UCS-2 to a JIS-X0208 code.
                // Returns true on success and false when the tables hold no mapping (then _Ch is left unchanged).
                if (_Ch < _Table::_Nlow) {
                    return true; // low values map to the identical byte value
                }

                if (_Ch < 0x100 && _Table::_Btw[_Ch - _Table::_Nlow] == _Ch) {
                    return true; // the one-byte table maps this value to itself, so it is emitted as-is
                }

                // otherwise binary search _Wvalid over the half-open index range [_Left, _Right);
                // _Wtb supplies the JIS code at the matching index
                // (the search presumes _Wvalid is sorted in ascending order)
                unsigned long _Left  = 0;
                unsigned long _Right = static_cast<unsigned long>(_STD size(_Table::_Wvalid));

                while (_Left < _Right) {
                    const unsigned long _Probe = (_Right + _Left) / 2;

                    if (_Ch == _Table::_Wvalid[_Probe]) { // found the code, map it
                        _Ch = _Table::_Wtb[_Probe];
                        return true;
                    }

                    if (_Ch < _Table::_Wvalid[_Probe]) {
                        _Right = _Probe; // continue in the lower half
                    } else {
                        _Left = _Probe + 1; // continue in the upper half
                    }
                }

                return false; // no JIS code for this wide value
            }

        public:
            // member types: _Mybase is the codecvt base this facet derives from, result is its status enumeration,
            // _Byte / extern_type is the external (multibyte) element, intern_type is the wide element,
            // and state_type is the conversion-state type
            using _Mybase     = _STD codecvt<_Elem, char, _Statype>;
            using result      = typename _Mybase::result;
            using _Byte       = char;
            using intern_type = _Elem;
            using extern_type = _Byte;
            using state_type  = _Statype;

            // construct the facet; _Refs is forwarded unchanged to the codecvt base constructor
            explicit codecvt_jis(size_t _Refs = 0) : _Mybase(_Refs) {}

            // nothing to release; the facet has no data members of its own
            ~codecvt_jis() noexcept override {}

        protected:
            result do_in(_Statype& _State, const _Byte* _First1, const _Byte* _Last1, const _Byte*& _Mid1,
                _Elem* _First2, _Elem* _Last2, _Elem*& _Mid2) const override {
                // Convert bytes [_First1, _Last1) to wide characters in [_First2, _Last2).
                //
                // The first byte of _State holds the shift state: 0 means 1-byte mode, nonzero means 2-byte mode.
                // On return _Mid1 is one past the last byte consumed and _Mid2 is one past the last wide
                // character stored. A shift sequence counts as consumed as soon as it is recognized.
                //
                // Returns:
                //   ok      - at least one wide character was stored (input may remain; compare _Mid1 to _Last1)
                //   partial - no wide character was stored (empty input, full output, or an incomplete sequence)
                //   error   - a bad escape, a byte outside 0x21..0x7e in 2-byte mode, an unmapped code, or a
                //             result above _Maxcode; _Mid1 is left at the first byte of the offending sequence
                char* _Pstate = reinterpret_cast<char*>(&_State);
                result _Ans   = _Mybase::partial;

                _Mid1 = _First1;
                _Mid2 = _First2;
                while (_Mid1 != _Last1 && _Mid2 != _Last2) { // convert a multibyte sequence
                    const unsigned char _By = static_cast<unsigned char>(*_Mid1);

                    if (_By == _ESC_CODE) {
                        if (_Last1 - _Mid1 < 4) {
                            break; // need the 3-byte escape plus at least one following byte
                        }

                        // NOTE(review): a 1-byte shift immediately followed by another ESC is not accepted here
                        // and ends up reported as a bad escape below -- confirm that this is intended.
                        if (_Mid1[1] == '(' && (_Mid1[2] == 'B' || _Mid1[2] == 'J')
                            && _Mid1[3] != _ESC_CODE) { // shift to 1-byte mode
                            _Mid1 += 3;
                            *_Pstate = 0;
                            continue;
                        }

                        if (_Mid1[1] == '$' && (_Mid1[2] == 'B' || _Mid1[2] == '@')) {
                            if (_Last1 - _Mid1 < 5) {
                                break; // not enough bytes for escape plus 2-byte code
                            }

                            // enough bytes, shift to 2-byte mode
                            _Mid1 += 3;
                            *_Pstate = 1;
                            continue;
                        }

                        return _Mybase::error; // bad escape sequence
                    }

                    unsigned short _Ch = _By;
                    int _Nbytes        = 1; // bytes making up the current code

                    if (*_Pstate != 0) { // 2-byte mode: both bytes must lie in 0x21..0x7e
                        if (_By < 0x21 || 0x7e < _By) {
                            return _Mybase::error; // bad first byte
                        }

                        if (_Last1 - _Mid1 < 2) {
                            break; // second byte not available yet
                        }

                        // peek at the second byte without advancing, so that _Mid1 still designates the start
                        // of the sequence if the conversion fails
                        const unsigned char _By2 = static_cast<unsigned char>(_Mid1[1]);

                        if (_By2 < 0x21 || 0x7e < _By2) {
                            return _Mybase::error; // bad second byte
                        }

                        _Ch     = static_cast<unsigned short>(_Ch << 8 | _By2);
                        _Nbytes = 2;
                    }

                    if (!_Jis_to_ucs(_Ch) || _Maxcode < _Ch) {
                        return _Mybase::error; // no mapping, or code too large
                    }

                    // commit: consume the input bytes and store the wide character
                    _Mid1 += _Nbytes;
                    *_Mid2++ = static_cast<_Elem>(_Ch);
                    _Ans     = _Mybase::ok;
                }

                return _Ans;
            }

            result do_out(_Statype& _State, const _Elem* _First1, const _Elem* _Last1, const _Elem*& _Mid1,
                _Byte* _First2, _Byte* _Last2, _Byte*& _Mid2) const override {
                // Convert wide characters [_First1, _Last1) to bytes in [_First2, _Last2).
                //
                // The first byte of _State holds the shift state: 0 means 1-byte mode, nonzero means 2-byte mode.
                // A 3-byte shift sequence is written whenever the next code needs the other mode. A character
                // is written only when its shift sequence and all of its bytes fit, and nothing is written
                // (and the state is not changed) for a character that is rejected.
                //
                // Returns:
                //   ok      - at least one wide character was consumed (compare _Mid1 to _Last1 for leftovers)
                //   partial - nothing was consumed (empty input or insufficient output room)
                //   error   - the wide value does not fit in 16 bits, exceeds _Maxcode, has no JIS mapping,
                //             maps to a bare ESC, or maps to a 2-byte code with a byte outside 0x21..0x7e;
                //             _Mid1 designates the offending wide character
                char* _Pstate = reinterpret_cast<char*>(&_State);

                _Mid1 = _First1;
                _Mid2 = _First2;
                while (_Mid1 != _Last1 && _Mid2 != _Last2) { // convert and put a wide char
                    const unsigned long _Uch = static_cast<unsigned long>(*_Mid1);
                    unsigned short _Ch       = static_cast<unsigned short>(_Uch);

                    if (static_cast<_Elem>(_Ch) != *_Mid1) {
                        // the value does not survive narrowing to 16 bits; reject it rather than let a large
                        // wide value alias an unrelated smaller code
                        return _Mybase::error;
                    }

                    if (_Maxcode < _Ch || !_Ucs_to_jis(_Ch)) {
                        return _Mybase::error; // code too large or no JIS mapping
                    }

                    if (_Ch <= 0xff) { // put a 1-byte code
                        if (_Ch == _ESC_CODE) {
                            return _Mybase::error; // can't output bald ESC
                        }

                        if (*_Pstate != 0) { // not in proper state
                            if (_Last2 - _Mid2 < 4) {
                                break; // need room for the escape plus the byte
                            }

                            // put shift to 1-byte state
                            *_Mid2++ = _ESC_CODE;
                            *_Mid2++ = '(';
                            *_Mid2++ = 'B';
                            *_Pstate = 0;
                        }

                        *_Mid2++ = static_cast<_Byte>(_Ch);
                    } else { // put a 2-byte code
                        const unsigned char _Lead  = static_cast<unsigned char>(_Ch >> 8);
                        const unsigned char _Trail = static_cast<unsigned char>(_Ch);

                        // validate both bytes before touching the output or the shift state
                        if (_Lead < 0x21 || 0x7e < _Lead) {
                            return _Mybase::error; // bad first byte
                        }

                        if (_Trail < 0x21 || 0x7e < _Trail) {
                            return _Mybase::error; // bad second byte
                        }

                        if (*_Pstate != 0) { // already in proper state, check room
                            if (_Last2 - _Mid2 < 2) {
                                break;
                            }
                        } else if (_Last2 - _Mid2 < 5) {
                            break; // need room for the escape plus both bytes
                        } else { // put shift to 2-byte state
                            *_Mid2++ = _ESC_CODE;
                            *_Mid2++ = '$';
                            *_Mid2++ = 'B';
                            *_Pstate = 1;
                        }

                        *_Mid2++ = static_cast<_Byte>(_Lead);
                        *_Mid2++ = static_cast<_Byte>(_Trail);
                    }

                    ++_Mid1;
                }

                return _First1 == _Mid1 ? _Mybase::partial : _Mybase::ok;
            }

            result do_unshift(_Statype& _State, _Byte* _First2, _Byte* _Last2, _Byte*& _Mid2) const override {
                // Write to [_First2, _Last2) the bytes that return the output to the default (1-byte) shift state.
                //
                // The first byte of _State tracks progress through the 3-byte home sequence ESC ( B:
                //   1 - in 2-byte mode, nothing written yet
                //   2 - ESC written
                //   3 - ESC ( written
                //   0 - default state, nothing (more) to write
                // so the sequence can be finished by a later call when the output range fills up.
                //
                // Returns ok once the state is 0; returns partial when the output range is exhausted first
                // (or when the state byte holds none of the values above, in which case nothing is written).
                // _Mid2 is set to one past the last byte written.
                char* _Pstate = reinterpret_cast<char*>(&_State);

                _Mid2 = _First2;
                for (;;) {
                    _Byte _Next;
                    char _Newstate;

                    switch (*_Pstate) {
                    case 0: // already home
                        return _Mybase::ok;

                    case 1: // need to home, put first of three bytes
                        _Next     = static_cast<_Byte>(_ESC_CODE);
                        _Newstate = 2;
                        break;

                    case 2: // put second of three bytes
                        _Next     = '(';
                        _Newstate = 3;
                        break;

                    case 3: // put third of three bytes
                        _Next     = 'B';
                        _Newstate = 0;
                        break;

                    default: // unrecognized state, leave everything untouched
                        return _Mybase::partial;
                    }

                    if (_Mid2 == _Last2) {
                        return _Mybase::partial; // no room left; never write past the end of the range
                    }

                    *_Mid2++ = _Next;
                    *_Pstate = _Newstate;
                }
            }

            int do_length(
                _Statype& _State, const _Byte* _First1, const _Byte* _Last1, size_t _Count) const noexcept override {
                // Count how many wide characters, at most _Count, the bytes [_First1, _Last1) convert to.
                // The conversion runs on a private copy of _State, so the caller's state is not modified.
                _Statype _Scratch = _State;
                size_t _Converted = 0;

                while (_Converted < _Count && _First1 != _Last1) {
                    // convert a single wide character into a throwaway slot
                    _Elem _Slot;
                    const _Byte* _Next1;
                    _Elem* _Next2;

                    const result _Res = do_in(_Scratch, _First1, _Last1, _Next1, &_Slot, &_Slot + 1, _Next2);

                    if (_Res == _Mybase::noconv) { // remaining bytes count one-for-one
                        return static_cast<int>(_Converted + (_Last1 - _First1));
                    }

                    if (_Res != _Mybase::ok) {
                        return static_cast<int>(_Converted); // error or partial
                    }

                    if (_Next2 == &_Slot + 1) {
                        ++_Converted; // replacement do_in might not convert one
                    }

                    _First1 = _Next1; // resume after the bytes just consumed
                }

                return static_cast<int>(_Converted);
            }

            // always false: do_in and do_out translate codes through the mapping tables and handle shift sequences
            bool do_always_noconv() const noexcept override { // return true if conversions never change input
                return false;
            }

            // 5 matches the longest output do_out writes for one wide character:
            // a 3-byte shift sequence followed by a 2-byte code
            int do_max_length() const noexcept override { // return maximum length required for a conversion
                return 5;
            }

            // NOTE(review): this facet keeps a shift state; the standard describes -1 as the do_encoding result
            // for state-dependent encodings -- confirm whether callers rely on the 0 returned here
            int do_encoding() const noexcept override { // return length of code sequence (from codecvt)
                return 0; // 0 => varying length
            }
        };
    } // namespace cvt
} // namespace stdext

#pragma pop_macro("new")
_STL_RESTORE_CLANG_WARNINGS
#pragma warning(pop)
#pragma pack(pop)

#endif // _STL_COMPILER_PREPROCESSOR
#endif // _CVT_XJIS_
