openocd_low_level_device.cc 7.33 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file openocd_low_level_device.cc
 */
#include <sstream>
#include <iomanip>

#include "micro_common.h"
#include "low_level_device.h"
#include "tcl_socket.h"

namespace tvm {
namespace runtime {

/*!
 * \brief OpenOCD low-level device for uTVM micro devices connected over JTAG
 */
class OpenOCDLowLevelDevice final : public LowLevelDevice {
 public:
  /*!
   * \brief constructor to initialize connection to openocd device
   * \param server_addr address of the OpenOCD server to connect to
   * \param port port of the OpenOCD server to connect to
   */
43
  explicit OpenOCDLowLevelDevice(const std::string& server_addr,
44
                                 int port) : socket_() {
45 46 47 48 49 50
    server_addr_ = server_addr;
    port_ = port;

    socket_.Connect(tvm::common::SockAddr(server_addr_.c_str(), port_));
    socket_.cmd_builder() << "halt 0";
    socket_.SendCommand();
51 52
  }

53
  void Read(DevPtr addr, void* buf, size_t num_bytes) {
54 55 56 57 58 59 60 61 62 63 64 65 66 67 68
    if (num_bytes == 0) {
      return;
    }

    // TODO(weberlo): Refactor between read and write.
    // Check if we need to chunk this write request.
    if (num_bytes > kMemTransferLimit) {
      char* curr_buf_ptr = reinterpret_cast<char*>(buf);
      while (num_bytes != 0) {
        size_t amount_to_read;
        if (num_bytes > kMemTransferLimit) {
          amount_to_read = kMemTransferLimit;
        } else {
          amount_to_read = num_bytes;
        }
69 70
        Read(addr, reinterpret_cast<void*>(curr_buf_ptr), amount_to_read);
        addr += amount_to_read;
71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121
        curr_buf_ptr += amount_to_read;
        num_bytes -= amount_to_read;
      }
      return;
    }
    {
      socket_.cmd_builder() << "array unset output";
      socket_.SendCommand();

      socket_.cmd_builder()
        << "mem2array output"
        << " " << std::dec << kWordSize
        << " " << addr.cast_to<void*>()
        // Round up any request sizes under a byte, since OpenOCD doesn't support
        // sub-byte-sized transfers.
        << " " << std::dec << (num_bytes < 8 ? 8 : num_bytes);
      socket_.SendCommand();
    }

    {
      socket_.cmd_builder() << "ocd_echo $output";
      socket_.SendCommand();
      const std::string& reply = socket_.last_reply();

      std::istringstream values(reply);
      char* char_buf = reinterpret_cast<char*>(buf);
      ssize_t req_bytes_remaining = num_bytes;
      uint32_t index;
      uint32_t val;
      while (req_bytes_remaining > 0) {
        // The response from this command pairs indices with the contents of the
        // memory at that index.
        values >> index;
        CHECK(index < num_bytes)
          << "index " << index <<
          " out of bounds (length " << num_bytes << ")";
        // Read the value into `curr_val`, instead of reading directly into
        // `buf_iter`, because otherwise it's interpreted as the ASCII value and
        // not the integral value.
        values >> val;
        char_buf[index] = static_cast<uint8_t>(val);
        req_bytes_remaining--;
      }
      if (num_bytes >= 8) {
        uint32_t check_index;
        values >> check_index;
        CHECK(check_index != index) << "more data in response than requested";
      }
    }
  }

122
  void Write(DevPtr addr, const void* buf, size_t num_bytes) {
123 124 125 126 127 128 129 130 131 132 133 134 135 136
    if (num_bytes == 0) {
      return;
    }

    // Check if we need to chunk this write request.
    if (num_bytes > kMemTransferLimit) {
      const char* curr_buf_ptr = reinterpret_cast<const char*>(buf);
      while (num_bytes != 0) {
        size_t amount_to_write;
        if (num_bytes > kMemTransferLimit) {
          amount_to_write = kMemTransferLimit;
        } else {
          amount_to_write = num_bytes;
        }
137 138
        Write(addr, reinterpret_cast<const void*>(curr_buf_ptr), amount_to_write);
        addr += amount_to_write;
139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173
        curr_buf_ptr += amount_to_write;
        num_bytes -= amount_to_write;
      }
      return;
    }

    // Clear `input` array.
    socket_.cmd_builder() << "array unset input";
    socket_.SendCommand();
    // Build a command to set the value of `input`.
    {
      std::ostringstream& cmd_builder = socket_.cmd_builder();
      cmd_builder << "array set input {";
      const char* char_buf = reinterpret_cast<const char*>(buf);
      for (size_t i = 0; i < num_bytes; i++) {
        // In a Tcl `array set` commmand, we need to pair the array indices with
        // their values.
        cmd_builder << i << " ";
        // Need to cast to uint, so the number representation of `buf[i]` is
        // printed, and not the ASCII representation.
        cmd_builder << static_cast<uint32_t>(char_buf[i]) << " ";
      }
      cmd_builder << "}";
      socket_.SendCommand();
    }
    {
      socket_.cmd_builder()
        << "array2mem input"
        << " " << std::dec << kWordSize
        << " " << addr.cast_to<void*>()
        << " " << std::dec << num_bytes;
      socket_.SendCommand();
    }
  }

174
  void Execute(DevPtr func_addr, DevPtr breakpoint_addr) {
175 176 177 178
    socket_.cmd_builder() << "halt 0";
    socket_.SendCommand();

    // Set a breakpoint at the beginning of `UTVMDone`.
179
    socket_.cmd_builder() << "bp " << breakpoint_addr.cast_to<void*>() << " 2";
180 181 182 183 184 185 186 187 188 189 190 191
    socket_.SendCommand();

    socket_.cmd_builder() << "resume " << func_addr.cast_to<void*>();
    socket_.SendCommand();

    socket_.cmd_builder() << "wait_halt " << kWaitTime;
    socket_.SendCommand();

    socket_.cmd_builder() << "halt 0";
    socket_.SendCommand();

    // Remove the breakpoint.
192
    socket_.cmd_builder() << "rbp " << breakpoint_addr.cast_to<void*>();
193 194 195 196 197 198 199 200 201 202
    socket_.SendCommand();
  }

  const char* device_type() const final {
    return "openocd";
  }

 private:
  /*! \brief socket used to communicate with the device through Tcl */
  TclSocket socket_;
203 204 205 206
  /*! \brief address of OpenOCD server */
  std::string server_addr_;
  /*! \brief port of OpenOCD server */
  int port_;
207 208 209 210 211 212 213 214 215 216 217

  /*! \brief number of bytes in a word on the target device (64-bit) */
  static const constexpr ssize_t kWordSize = 8;
  // NOTE: OpenOCD will call any request larger than this constant an "absurd
  // request".
  /*! \brief maximum number of bytes allowed in a single memory transfer */
  static const constexpr ssize_t kMemTransferLimit = 64000;
  /*! \brief number of milliseconds to wait for function execution to halt */
  static const constexpr int kWaitTime = 10000;
};

218
const std::shared_ptr<LowLevelDevice> OpenOCDLowLevelDeviceCreate(const std::string& server_addr,
219 220
                                                                  int port) {
  std::shared_ptr<LowLevelDevice> lld =
221
      std::make_shared<OpenOCDLowLevelDevice>(server_addr, port);
222 223 224 225 226
  return lld;
}

}  // namespace runtime
}  // namespace tvm