// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "ortx_types.h"

// Version of the C API described by this header.
// `static` gives the constant internal linkage in both C and C++. Without it, every
// C translation unit that includes this header emits an external definition of
// API_VERSION, and linking two such units fails with a duplicate-symbol error.
// (In C++ a namespace-scope `const` already has internal linkage, so C++ users see no change.)
static const int API_VERSION = 1;

/* Kind tag stored in every library-managed object (see OrtxObject::ext_kind_). */
typedef enum {
  kOrtxKindUnknown = 0,

  // Real kinds start at a non-trivial value rather than 1, which helps validate
  // that a pointer really refers to a library object.
  // Each kind is pinned to an explicit offset so the numeric values stay stable.
  kOrtxKindBegin = 0x7788,
  kOrtxKindTokenizer = kOrtxKindBegin,               // 0x7788
  kOrtxKindStringArray = kOrtxKindBegin + 1,         // 0x7789
  kOrtxKindTokenId2DArray = kOrtxKindBegin + 2,      // 0x778A
  kOrtxKindDetokenizerCache = kOrtxKindBegin + 3,    // 0x778B
  kOrtxKindProcessor = kOrtxKindBegin + 4,           // 0x778C
  kOrtxKindRawImages = kOrtxKindBegin + 5,           // 0x778D
  kOrtxKindTensorResult = kOrtxKindBegin + 6,        // 0x778E
  kOrtxKindProcessorResult = kOrtxKindBegin + 7,     // 0x778F
  kOrtxKindTensor = kOrtxKindBegin + 8,              // 0x7790
  kOrtxKindFeatureExtractor = kOrtxKindBegin + 9,    // 0x7791
  kOrtxKindRawAudios = kOrtxKindBegin + 10,          // 0x7792
  kOrtxKindString = kOrtxKindBegin + 11,             // 0x7793
  kOrtxKindEnd = 0x7999
} extObjectKind_t;

// All objects managed by the library should be 'derived' from this struct,
// and are eventually released by OrtxDispose in C++, or by ORTX_DISPOSE in C.
typedef struct {
  extObjectKind_t ext_kind_;  // tag identifying which kind of object this is
} OrtxObject;

/* Tensors and tensor results are exposed to callers as plain OrtxObject handles. */
typedef OrtxObject OrtxTensor, OrtxTensorResult;

// Neither C nor C++ implicitly converts the address of a handle of some other
// struct type into `OrtxObject**`, so this macro performs the cast before calling
// OrtxDispose. `obj` must be an lvalue (typically the handle variable itself);
// it is parenthesized so that any lvalue expression expands safely.
#define ORTX_DISPOSE(obj) OrtxDispose((OrtxObject**)&(obj))

#ifdef __cplusplus
// Evaluates `expr` exactly once and returns that status from the enclosing
// function when `!status.IsOk()`. C++ only: it relies on `auto` type deduction
// and a member-function call, neither of which is available in C.
#define ORTX_RETURN_IF_ERROR(expr) \
  do {                             \
    auto _status = (expr);         \
    if (!_status.IsOk()) {         \
      return _status;              \
    }                              \
  } while (0)
#endif  // __cplusplus

// Token id type: an unsigned 32-bit integer.
// NOTE(review): this header includes only ortx_types.h, so uint32_t is presumably
// provided by it (or by a header it includes) — confirm.
typedef uint32_t extTokenId_t;

#ifdef __cplusplus
extern "C" {
#endif

/** \brief Get the current C ABI version of this library
 *
 * Takes no arguments and returns the version as a plain int.
 * NOTE(review): presumably comparable with the header's API_VERSION constant, so a
 * caller can detect a header/library mismatch — confirm against the implementation.
 *
 * \snippet{doc} snippets.dox int Return Value
 */
int ORTX_API_CALL OrtxGetAPIVersion(void);

/** \brief Get the last error message generated by the library
 *
 * Takes no arguments; the message is returned directly.
 * NOTE(review): the lifetime of the returned string (e.g. whether it is overwritten
 * by the next failing call, or whether it is kept per thread) is not visible in this
 * header — confirm before caching the pointer.
 *
 * \return Pointer to the last error message (read-only; do not free)
 */
const char* ORTX_API_CALL OrtxGetLastErrorMessage(void);

/** \brief Create a new object of the specified kind
 *
 * The object is returned through `object`. Library-managed objects are meant to be
 * released with OrtxDispose (or the ORTX_DISPOSE macro).
 * NOTE(review): the number and types of the variadic arguments depend on `kind`
 * and are not documented in this header — confirm per kind before calling.
 *
 * \param kind The kind of object to create
 * \param object Pointer to store the created object
 * \param ... Additional arguments based on the kind of object
 * \return Error code indicating the success or failure of the operation
 */
extError_t ORTX_API_CALL OrtxCreate(extObjectKind_t kind, OrtxObject** object, ...);

/** \brief Dispose the specified object
 *
 * Takes the address of the caller's handle rather than the handle itself.
 * The address of a handle of another struct type does not convert implicitly to
 * OrtxObject**, so callers typically use the ORTX_DISPOSE macro, which performs
 * the cast.
 * NOTE(review): presumably also resets *object to NULL after releasing it (the
 * reason for the extra indirection) — confirm against the implementation.
 *
 * \param object Pointer to the object to dispose
 * \return Error code indicating the success or failure of the operation
 */
extError_t ORTX_API_CALL OrtxDispose(OrtxObject** object);

/** \brief Dispose the specified object, given the handle itself
 *
 * Unlike OrtxDispose, this takes the handle by value, so it cannot modify the
 * caller's variable. After the call the caller's pointer still holds the old
 * address and presumably must not be used again.
 * NOTE(review): any other behavioral difference from OrtxDispose is not visible in
 * this header — confirm against the implementation.
 *
 * \param object The object to dispose
 * \return Error code indicating the success or failure of the operation
 */
extError_t ORTX_API_CALL OrtxDisposeOnly(OrtxObject* object);

/**
 * @brief Retrieves the tensor at the specified index from the given tensor result.
 *
 * This function allows you to access a specific tensor from a tensor result object.
 * The result is taken as const; the tensor handle is written through `tensor`.
 * NOTE(review): the valid index range (presumably zero-based, below the number of
 * tensors in `result`) is not visible here — confirm. It is also not visible whether
 * the returned tensor must be disposed separately or is owned by `result`.
 *
 * @param result The tensor result object from which to retrieve the tensor.
 * @param index The index of the tensor to retrieve.
 * @param tensor A pointer to a variable that will hold the retrieved tensor.
 * @return An error code indicating the success or failure of the operation.
 */
extError_t ORTX_API_CALL OrtxTensorResultGetAt(const OrtxTensorResult* result, size_t index, OrtxTensor** tensor);

/**
 * @brief Retrieves the data type of the given tensor.
 *
 * This function returns the data type of the specified tensor. The data type is
 * stored in the `type` parameter; the tensor itself is taken as const.
 *
 * @param tensor The tensor for which to retrieve the data type.
 * @param type   A pointer to a variable that will hold the retrieved data type
 *               (an extDataType_t value).
 *
 * @return An `extError_t` value indicating the success or failure of the operation.
 */
extError_t ORTX_API_CALL OrtxGetTensorType(const OrtxTensor* tensor, extDataType_t* type);

/**
 * @brief Retrieves the size of each element in the given tensor.
 *
 * This function determines the size of each element in the specified tensor and
 * stores it in the provided size variable.
 * NOTE(review): presumably the size in bytes of one element of the tensor's data
 * type (see OrtxGetTensorType) — confirm.
 *
 * @param tensor A pointer to the OrtxTensor object.
 * @param size A pointer to a size_t variable to store the size of each element.
 * @return An extError_t value indicating the success or failure of the operation.
 */
extError_t ORTX_API_CALL OrtxGetTensorSizeOfElement(const OrtxTensor* tensor, size_t* size);

/** \brief Get the data from the tensor
 *
 * All outputs are read-only views: `data` and `shape` receive const pointers.
 * NOTE(review): presumably `*shape` points to `*num_dims` entries, and both buffers
 * are owned by `tensor` and stay valid only while it is alive — confirm before
 * holding on to them.
 *
 * \param tensor The tensor object
 * \param data Pointer to store the data
 * \param shape Pointer to store the shape
 * \param num_dims Pointer to store the number of dimensions
 * \return Error code indicating the success or failure of the operation
 */
extError_t ORTX_API_CALL OrtxGetTensorData(const OrtxTensor* tensor, const void** data, const int64_t** shape,
                                           size_t* num_dims);

#ifdef __cplusplus
}
#endif
